File size: 8,976 Bytes
c7260b4
 
8b11e04
5a5e927
 
c7260b4
 
 
8b11e04
 
4f37b5c
c7260b4
4f37b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7260b4
4f37b5c
 
 
 
 
 
 
 
 
c7260b4
4f37b5c
 
c7260b4
4f37b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6bfaba
4f37b5c
 
 
 
 
 
 
c7260b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26c35dd
c7260b4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer
import re
import ast

# Load the fine-tuned model and tokenizer.
model_repo_path = 'sabssag/Latex_to_Python_CodeT5-base'


@st.cache_resource
def _load_model_and_tokenizer(repo_path):
    """
    Load the CodeT5 checkpoint and its tokenizer from ``repo_path``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the (expensive) load happen once per server
    process instead of on every rerun.

    The weights are loaded in float32: nothing in this file moves the model
    to a GPU, and half-precision inference on CPU is slow (and unsupported
    for some ops on older torch versions). T5-family checkpoints are also
    prone to numeric overflow in float16.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    loaded_model = T5ForConditionalGeneration.from_pretrained(repo_path, torch_dtype=torch.float32)
    loaded_model.eval()
    loaded_tokenizer = RobertaTokenizer.from_pretrained(repo_path)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_repo_path)

# Fix unmatched brackets
def fix_unmatched_brackets(code):
    """
    Balance (), [] and {} in (possibly malformed) generated code.

    The code is scanned left to right. Brackets inside string literals and
    ``#`` comments are ignored. The repair rules are:

    * A closing bracket whose opener is still open, but not innermost, first
      closes the inner brackets: ``"[(]"`` -> ``"[()]"``.
    * A closing bracket with no matching opener at all is dropped.
    * Brackets still open at the end are closed in reverse order. Trailing
      newlines are stripped first. If the last line ends in a comment, the
      closers go on a new line so they are not commented out.

    String tracking is heuristic: a quote starts a literal that ends at the
    next unescaped identical quote or at a newline.

    Args:
        code: Source text.

    Returns:
        The repaired source text (unchanged if already balanced).
    """
    pairs = {'(': ')', '[': ']', '{': '}'}
    opener_of = {close: open_ for open_, close in pairs.items()}

    stack = []               # open brackets, innermost last
    out = []
    quote = None             # delimiter of the string literal being scanned
    escaped = False
    in_comment = False
    tail_in_comment = False  # was the last non-newline char inside a comment?

    for char in code:
        keep = True
        if char == '\n':
            # Comments and single-line strings cannot continue past a newline.
            quote, escaped, in_comment = None, False, False
        elif in_comment:
            pass
        elif quote is not None:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
        elif char in ('"', "'"):
            quote = char
        elif char == '#':
            in_comment = True
        elif char in pairs:
            stack.append(char)
        elif char in opener_of:
            opener = opener_of[char]
            if opener in stack:
                # Close any inner brackets left open, then match this one.
                while stack[-1] != opener:
                    out.append(pairs[stack.pop()])
                stack.pop()
            else:
                keep = False  # stray closer with nothing to close
        if keep:
            out.append(char)
        if char != '\n':
            tail_in_comment = in_comment

    result = ''.join(out)
    if stack:
        result = result.rstrip('\n')
        if tail_in_comment:
            result += '\n'
        result += ''.join(pairs[bracket] for bracket in reversed(stack))
    return result
# Validate and correct bracket balance
def validate_bracket_balance(code):
    """
    Second-pass bracket check that removes misaligned closers.

    Brackets inside string literals and ``#`` comments are ignored. The
    repair rules are:

    * A closing bracket that does not match the innermost open bracket is
      removed, and the scan continues.
    * Brackets still open at the end are closed in reverse order.
    * If the code ends inside a comment, the closers are put on a new line
      so they are not commented out.

    String tracking is heuristic: a quote starts a literal that ends at the
    next unescaped identical quote or at a newline.

    Args:
        code: Source text.

    Returns:
        Source text with balanced (), [] and {}.
    """
    pairs = {'(': ')', '[': ']', '{': '}'}
    opener_of = {close: open_ for open_, close in pairs.items()}

    stack = []        # open brackets, innermost last
    kept = []
    quote = None      # delimiter of the string literal being scanned
    escaped = False
    in_comment = False

    for char in code:
        if char == '\n':
            # Comments and single-line strings cannot continue past a newline.
            quote, escaped, in_comment = None, False, False
        elif in_comment:
            pass
        elif quote is not None:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
        elif char in ('"', "'"):
            quote = char
        elif char == '#':
            in_comment = True
        elif char in pairs:
            stack.append(char)
        elif char in opener_of:
            if stack and stack[-1] == opener_of[char]:
                stack.pop()
            else:
                # Drop just this bracket; swapping it for '#' would comment
                # out the rest of the line.
                continue
        kept.append(char)

    result = ''.join(kept)
    if stack:
        if in_comment:
            result += '\n'
        result += ''.join(pairs[bracket] for bracket in reversed(stack))
    return result

# Add missing imports based on used functions
def add_missing_imports(code):
    """
    Add missing sympy/numpy imports and repair malformed sympy import lines.

    A known sympy function used as a bare name (not as an attribute such as
    ``np.pi`` or ``sp.sqrt``) that is not imported via a column-0
    ``from sympy import ...`` line is added to the imports.

    Malformed lines such as ``from sympy import, pi`` or
    ``from sympy import sqrt,`` are repaired rather than deleted, so the
    names they list are kept.

    When anything has to change, all such sympy import lines are replaced by
    one sorted, consolidated import at the top of the code.
    ``import numpy as np`` is prepended when ``np.`` is used without it.

    Args:
        code: Python source text.

    Returns:
        The source text with imports fixed up.
    """
    sympy_funcs = {
        "cot", "sqrt", "pi", "sin", "cos", "tan", "log", "Abs", "exp",
        "factorial", "csc", "sec", "asin", "acos", "atan", "Eq", "symbols", "Function", "Derivative"
    }

    # Bare identifiers only: attribute accesses like ``np.pi`` are already
    # qualified by another module and must not trigger a sympy import.
    used_names = set(re.findall(r'(?<![\w.])[A-Za-z_]\w*', code))

    # Whole ``from sympy import ...`` lines, well-formed or not. The capture is
    # everything after ``import``; ``[\w ,]`` cannot cross a newline.
    import_line = re.compile(r'^from sympy import\b([\w ,]*)$\n?', re.MULTILINE)

    existing = set()
    malformed = False
    for names in import_line.findall(code):
        parts = [part.strip() for part in names.split(',')]
        if '' in parts:  # leading/trailing/doubled comma or empty name list
            malformed = True
        existing.update(part for part in parts if part)

    required = (used_names & sympy_funcs) - existing

    if required or malformed:
        code = import_line.sub('', code)
        wanted = existing | required
        if wanted:  # never emit an empty ``from sympy import``
            code = f"from sympy import {', '.join(sorted(wanted))}\n" + code

    # Add numpy import if necessary (``np.`` as a standalone name, not ``inp.``).
    if re.search(r'(?<![\w.])np\.', code) and "import numpy as np" not in code:
        code = "import numpy as np\n" + code

    return code
# Enhanced removal of evalf() calls, handling malformed cases
def remove_evalf(code):
    """
    Strip argument-less ``evalf()`` calls from generated code.

    Both the method form ``.evalf()`` and the stray ``*evalf()`` form are
    deleted. Calls that pass arguments (e.g. ``.evalf(5)``) are not touched.
    The text is then passed through ``fix_unmatched_brackets`` to rebalance
    any brackets.
    """
    for evalf_call in (r'\.evalf\(\)', r'\*evalf\(\)'):
        code = re.sub(evalf_call, '', code)
    return fix_unmatched_brackets(code)

def handle_sum_errors(code):
    """
    Unwrap ``sum(...)`` calls whose argument is a single non-iterable expression.

    Intended for generated code that wraps a plain expression in ``sum()``.
    Only calls to the bare builtin name ``sum`` with no nested parentheses
    are considered; ``np.sum``, ``cumsum``, ``fsum`` etc. are not touched.

    A call is left alone when its argument looks iterable: it contains a
    comma, a bracket or brace, or a ``for`` clause, or it is a bare
    (possibly dotted) name such as ``values``.

    The unwrapped argument stays in parentheses so operator precedence is
    preserved, e.g. ``2*sum(x + 1)`` -> ``2*(x + 1)``.

    Args:
        code: Python source text.

    Returns:
        The source text with the qualifying ``sum`` calls unwrapped.
    """
    invalid_sum_pattern = re.compile(r'(?<![\w.])sum\(([^()]+)\)')

    def unwrap(match):
        argument = match.group(1)
        looks_iterable = (
            ',' in argument
            or re.search(r'[\[\]{}]', argument) is not None
            or re.search(r'\bfor\b', argument) is not None
            or re.fullmatch(r'\s*[A-Za-z_][\w.]*\s*', argument) is not None
        )
        return match.group(0) if looks_iterable else f'({argument})'

    return invalid_sum_pattern.sub(unwrap, code)

def complete_try_catch_block(code):
    """
    Give every ``try:`` block that has no handler a generic ``except``.

    Each ``try:`` header line (alone on its line, optionally followed by a
    comment) is examined separately. Its body is the run of following lines
    indented deeper than the header; blank lines are skipped.

    If the first line after the body does not start with ``except`` or
    ``finally`` at the header's indentation, this handler is inserted
    directly after the body, with the ``except`` at the header's
    indentation::

        except Exception as e:
            print(f"Error: {e}")

    An empty ``try:`` body additionally gets a ``pass`` statement. The
    original body stays inside the ``try``.

    Args:
        code: Python source text.

    Returns:
        The source text with missing handlers added.
    """
    lines = code.split('\n')
    insertions = []  # (insert_at, header_index, lines_to_insert)

    for header_index, line in enumerate(lines):
        header = re.match(r'([ \t]*)try\s*:\s*(?:#.*)?$', line)
        if header is None:
            continue
        indent = header.group(1)

        body_indent = None        # indentation of the first body line
        last_body = header_index  # index of the last body line
        follower = ''             # first non-blank line after the body
        for index in range(header_index + 1, len(lines)):
            text = lines[index]
            stripped = text.lstrip()
            if not stripped:
                continue
            depth = len(text) - len(stripped)
            if depth <= len(indent):
                follower = text
                break
            if body_indent is None:
                body_indent = text[:depth]
            last_body = index

        if re.match(re.escape(indent) + r'(?:except|finally)\b', follower):
            continue  # already has a handler

        inner = body_indent if body_indent is not None else indent + '    '
        handler = [] if body_indent is not None else [inner + 'pass']
        handler += [indent + 'except Exception as e:', inner + 'print(f"Error: {e}")']
        insertions.append((last_body + 1, header_index, handler))

    # Apply bottom-up so earlier insertion points stay valid. When an outer and
    # an inner try share an insertion point, insert the outer handler first so
    # the inner one lands above it.
    for insert_at, _, handler in sorted(insertions, key=lambda item: (-item[0], item[1])):
        lines[insert_at:insert_at] = handler
    return '\n'.join(lines)

import re

def remove_extra_variables_from_function(code):
    """
    Drop parameters that a function's body never references.

    Every single-line ``def`` header starting a line is handled on its own.
    The body is the rest of the header line plus the following lines that are
    blank or indented deeper than the ``def``.

    A parameter is kept when its bare name (annotation and default stripped)
    appears as an identifier in that body. The following are always kept:
    ``*args``/``**kwargs``, a leading ``self``/``cls``, and the ``/``/``*``
    markers. A marker is dropped only when it would be left dangling.

    Other definitions are never rewritten with this function's header.
    Call sites are not updated: a caller that passes a removed argument will
    fail.

    Args:
        code: Python source text.

    Returns:
        The source text with each matching ``def`` header rewritten.
    """
    header_re = re.compile(
        r'^([ \t]*)((?:async[ \t]+)?def)[ \t]+([A-Za-z_]\w*)[ \t]*\((.*?)\)([ \t]*->[^:\n]*)?:',
        re.MULTILINE,
    )

    def split_params(text):
        # Split on top-level commas only, so defaults like ``x=(1, 2)`` and
        # annotations like ``dict[str, int]`` stay whole.
        parts, current, depth = [], [], 0
        for char in text:
            if char in '([{':
                depth += 1
            elif char in ')]}':
                depth -= 1
            if char == ',' and depth == 0:
                parts.append(''.join(current).strip())
                current = []
            else:
                current.append(char)
        parts.append(''.join(current).strip())
        return [part for part in parts if part]

    def body_of(match):
        # Rest of the header line plus the indented block that follows it.
        first_line, _, remainder = code[match.end():].partition('\n')
        body_lines = [first_line]
        header_width = len(match.group(1))
        for line in remainder.split('\n'):
            stripped = line.lstrip()
            if stripped and len(line) - len(stripped) <= header_width:
                break
            body_lines.append(line)
        return '\n'.join(body_lines)

    def rewrite(match):
        indent, keyword, func_name, params, returns = match.groups()
        used_vars = set(re.findall(r'\b[A-Za-z_]\w*', body_of(match)))
        kept = []
        for position, param in enumerate(split_params(params)):
            bare = re.match(r'[A-Za-z_]\w*', param)
            if (
                bare is None  # '*', '/', '*args', '**kwargs'
                or bare.group() in used_vars
                or (position == 0 and bare.group() in ('self', 'cls'))
            ):
                kept.append(param)
        # '/' needs parameters before it and a bare '*' needs parameters after it.
        if kept and kept[0] == '/':
            kept.pop(0)
        if kept and kept[-1] == '*':
            kept.pop()
        return f"{indent}{keyword} {func_name}({', '.join(kept)}){returns or ''}:"

    return header_re.sub(rewrite, code)

# Post-process the generated code
def post_process_code(code):
    """
    Run raw model output through every clean-up pass, in a fixed order.

    Each pass takes and returns source text; the output of one pass is the
    input of the next.
    """
    passes = (
        fix_unmatched_brackets,
        validate_bracket_balance,
        add_missing_imports,
        remove_evalf,
        handle_sum_errors,
        complete_try_catch_block,
        remove_extra_variables_from_function,
    )
    for clean in passes:
        code = clean(code)
    return code

# Generate the final code from LaTeX
def generate_code(latex_expression, max_length=512):
    """
    Translate a LaTeX expression into post-processed Python source.

    Args:
        latex_expression: LaTeX source text. Surrounding whitespace is
            stripped before it is inserted into the prompt.
        max_length: Maximum length of the generated sequence, passed
            through to ``model.generate``.

    Returns:
        The decoded model output after ``post_process_code``.
    """
    prompt = f"Latex Expression: {latex_expression.strip()} Solution:"
    # Put the input tensors on the same device as the model, so generation
    # keeps working if the model is ever moved off the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=max_length)
    generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return post_process_code(generated_code)




# Streamlit app layout
st.title("LaTeX to Python Code Generator")

# Remember the last submitted expression across Streamlit reruns.
if 'latex_expr' not in st.session_state:
    st.session_state.latex_expr = ""

# User input for LaTeX expression, pre-filled with the last submission.
latex_input = st.text_area("Enter the LaTeX Expression", value=st.session_state.latex_expr, height=150)

if st.button("Generate Code"):
    # Reject whitespace-only input instead of sending a blank prompt to the model.
    if latex_input.strip():
        st.session_state.latex_expr = latex_input
        with st.spinner("Generating Python Code..."):
            try:
                generated_code = generate_code(latex_expression=st.session_state.latex_expr)
            except Exception as e:  # UI boundary: report the failure instead of crashing the page
                st.error(f"Error during code generation: {e}")
            else:
                # Display the generated code
                st.subheader("Generated Python Code")
                st.code(generated_code, language='python')
    else:
        st.warning("Please enter a LaTeX expression to generate Python code.")