sabssag committed on
Commit
4f37b5c
·
verified ·
1 Parent(s): c7260b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -8
app.py CHANGED
@@ -8,18 +8,208 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
  model_repo_path = 'sabssag/Latex_to_Python_CodeT5-base'
9
  model = AutoModelForSeq2SeqLM.from_pretrained(model_repo_path)
10
  tokenizer = AutoTokenizer.from_pretrained(model_repo_path)
 
11
 
12
- # Function to generate Python code from LaTeX expression
13
- def generate_code_from_latex(latex_expression, max_length=256):
14
- inputs = tokenizer(f"Latex Expression: {latex_expression} Solution:", return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Generate the output
17
- outputs = model.generate(**inputs, max_length=max_length)
 
 
 
 
 
 
 
18
 
19
- # Decode the output into Python code
20
- generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
 
22
- return generated_code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Streamlit app layout
25
  st.title("LaTeX to Python Code Generator")
 
8
  model_repo_path = 'sabssag/Latex_to_Python_CodeT5-base'
9
  model = AutoModelForSeq2SeqLM.from_pretrained(model_repo_path)
10
  tokenizer = AutoTokenizer.from_pretrained(model_repo_path)
11
+ model.eval()
12
 
13
+
14
+ import re
15
+ import ast
16
+ import torch
17
+
18
+ # Fix unmatched brackets
19
def fix_unmatched_brackets(code):
    """
    Return `code` with its round, square and curly brackets balanced.

    The text is scanned once, left to right, while a stack tracks the opening
    brackets that are still unclosed:

    * a closing bracket whose opener is not on the stack is dropped;
    * a closing bracket whose opener is on the stack but not on top first gets
      the closers of every opener above it inserted in front of it;
    * openers still on the stack at the end are closed at the end of the text
      (a single trailing newline is removed first, so the closers land on the
      last line of code rather than on a new line).

    Brackets are matched purely textually: brackets inside string literals or
    comments are treated like any other bracket.

    Args:
        code: Source text to repair.

    Returns:
        The repaired text. Input that is already balanced is returned unchanged.
    """
    bracket_pairs = {'(': ')', '[': ']', '{': '}'}
    opener_of = {closer: opener for opener, closer in bracket_pairs.items()}

    stack = []   # openers that have not been closed yet
    pieces = []  # output characters, joined once at the end

    for char in code:
        if char in bracket_pairs:
            stack.append(char)
            pieces.append(char)
        elif char in opener_of:
            wanted = opener_of[char]
            if wanted not in stack:
                # Nothing this closer could belong to: drop it.
                continue
            # Close everything opened after the matching opener so the closer
            # pairs up with it instead of being left dangling.
            while stack[-1] != wanted:
                pieces.append(bracket_pairs[stack.pop()])
            stack.pop()
            pieces.append(char)
        else:
            pieces.append(char)

    if stack:
        # A non-empty stack implies at least one opener was emitted, so
        # `pieces` cannot be empty here.
        if pieces[-1] == '\n':
            pieces.pop()
        pieces.extend(bracket_pairs[opener] for opener in reversed(stack))

    return ''.join(pieces)
58
+ # Validate and correct bracket balance
59
def validate_bracket_balance(code):
    """
    Check the bracket balance of `code` and repair it where it is off.

    Every closing bracket that does not match the most recently opened,
    still-unclosed bracket is removed, and the whole text is scanned (not just
    the part up to the first problem). Closers for the openers that remain
    unclosed are then appended at the end of the text, innermost first.

    Removing the stray closer, rather than overwriting it with `#`, keeps the
    rest of its line -- and any closers appended afterwards -- from turning
    into a comment.

    Args:
        code: Source text to check.

    Returns:
        The text with balanced brackets. Balanced input is returned unchanged.
    """
    closer_for = {'(': ')', '[': ']', '{': '}'}
    opener_for = {closer: opener for opener, closer in closer_for.items()}

    stack = []
    kept = []
    for char in code:
        if char in closer_for:
            stack.append(char)
        elif char in opener_for:
            if not stack or stack[-1] != opener_for[char]:
                # Misaligned closer: leave it out of the result.
                continue
            stack.pop()
        kept.append(char)

    # Close whatever is still open, innermost bracket first.
    kept.extend(closer_for[opener] for opener in reversed(stack))
    return ''.join(kept)
80
+
81
+ # Add missing imports based on used functions
82
def add_missing_imports(code):
    """
    Prepend the sympy / numpy imports that `code` appears to need.

    Steps, in order:

    1. Collect every identifier in the text and every name already listed in a
       `from sympy import ...` statement.
    2. If a known sympy name is used but not imported, remove the existing
       `from sympy import ...` lines and prepend one consolidated, sorted
       import statement.
    3. Delete malformed statements of the form `from sympy import, pi`.
    4. Prepend `import numpy as np` when the text uses `np.` without
       importing it.

    Args:
        code: Generated source text.

    Returns:
        The text with the import lines adjusted.
    """
    sympy_funcs = {
        "cot", "sqrt", "pi", "sin", "cos", "tan", "log", "Abs", "exp",
        "factorial", "csc", "sec", "asin", "acos", "atan", "Eq", "symbols", "Function", "Derivative"
    }

    # Every identifier in the text counts as "used" (this includes the names
    # inside import lines themselves).
    used_names = set(re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', code))

    # Names already imported from sympy. Empty fragments (from a trailing
    # comma such as `import sin, `) are skipped so they cannot end up in the
    # regenerated statement as `import , sin`.
    imported = {
        name.strip()
        for clause in re.findall(r'from sympy import ([a-zA-Z_, ]+)', code)
        for name in clause.split(',')
        if name.strip()
    }

    missing = (used_names & sympy_funcs) - imported
    if missing:
        consolidated = f"from sympy import {', '.join(sorted(imported | missing))}\n"
        # `(?:\n|$)` also removes an import that sits on the last line without
        # a trailing newline.
        code = re.sub(r'from sympy import [a-zA-Z_, ]+(?:\n|$)', '', code)
        code = consolidated + code

    # Drop malformed imports. `.` does not match a newline, so only the
    # offending line is removed.
    code = re.sub(r'from sympy import,.*(?:\n|$)', '', code)

    # `\b` keeps names that merely end in "np" (e.g. `inp.value`) from
    # triggering the numpy import.
    if re.search(r'\bnp\.', code) and "import numpy as np" not in code:
        code = "import numpy as np\n" + code

    return code
122
+ # Enhanced removal of evalf() calls, handling malformed cases
123
def remove_evalf(code):
    """
    Strip argument-less `evalf()` calls from `code`.

    Two spellings are removed, one after the other: the regular method call
    `.evalf()` and the malformed `*evalf()`. The result is then passed through
    `fix_unmatched_brackets`.
    """
    # Order matters: the `.evalf()` pass runs first, the `*evalf()` pass second.
    for evalf_pattern in (r'\.evalf\(\)', r'\*evalf\(\)'):
        code = re.sub(evalf_pattern, '', code)

    return fix_unmatched_brackets(code)
135
+
136
def handle_sum_errors(code):
    """
    Unwrap `sum(...)` calls whose argument does not look like an iterable.

    `sum(expr)` is replaced by `expr` when `expr` contains no parentheses.
    The following are left untouched:

    * names that merely end in "sum" (`checksum(x)`, `my_sum(x)`) and
      attribute calls (`np.sum(x)`, `arr.sum(axis)`);
    * calls whose argument is a list/set/dict display or comprehension
      (starts with `[` or `{`) or a generator expression (`... for ... in ...`),
      since those are iterable.

    Args:
        code: Generated source text.

    Returns:
        The text with the qualifying `sum(...)` wrappers removed.
    """
    # The lookbehind rejects a preceding word character or dot, so only the
    # bare built-in name `sum` is matched.
    sum_call = re.compile(r'(?<![\w.])sum\(([^()]+)\)')
    generator_expr = re.compile(r'\bfor\b.+?\bin\b', re.DOTALL)

    def _unwrap(match):
        argument = match.group(1)
        stripped = argument.strip()
        if stripped.startswith(('[', '{')) or generator_expr.search(stripped):
            return match.group(0)  # clearly iterable: keep the call as is
        return argument

    return sum_call.sub(_unwrap, code)
147
+
148
def complete_try_catch_block(code):
    """
    Give every `try:` block in `code` that lacks a handler a generic one.

    For each line consisting only of `try:` (optionally followed by a
    comment), the block body is taken to be the following lines that are
    indented deeper than the `try:` itself. If the first line after the body
    is not an `except`/`finally` clause at the same indentation, this clause
    is inserted right after the body:

        except Exception as e:
            print(f"Error: {e}")

    A `try:` with no body at all also receives a `pass` statement so the
    result stays syntactically valid.

    Args:
        code: Generated source text.

    Returns:
        The text with the missing handlers added. Text without an incomplete
        `try:` block is returned unchanged.
    """
    if 'try:' not in code:
        return code

    header_re = re.compile(r'^([ \t]*)try:[ \t]*(?:#.*)?$')
    handler_re = re.compile(r'(?:except|finally)\b')
    lines = code.split('\n')
    insertions = {}  # line index -> lines to insert in front of that index

    for start, line in enumerate(lines):
        header = header_re.match(line)
        if header is None:
            continue
        indent = header.group(1)

        body_indent = None    # indentation of the first body line, if any
        body_end = start + 1  # index just past the last body line
        follower = None       # first non-blank line after the body
        cursor = start + 1
        while cursor < len(lines):
            candidate = lines[cursor]
            stripped = candidate.lstrip(' \t')
            if not stripped.strip():
                cursor += 1  # blank lines do not end the body
                continue
            lead = candidate[:len(candidate) - len(stripped)]
            if len(lead) > len(indent) and lead.startswith(indent):
                if body_indent is None:
                    body_indent = lead
                cursor += 1
                body_end = cursor
                continue
            follower = (lead, stripped)
            break

        if (follower is not None and follower[0] == indent
                and handler_re.match(follower[1])):
            continue  # already followed by except/finally

        inner = body_indent if body_indent is not None else indent + '    '
        patch = []
        if body_indent is None:
            patch.append(f'{inner}pass')
        patch.append(f'{indent}except Exception as e:')
        patch.append(f'{inner}print(f"Error: {{e}}")')
        # A nested try is found after its enclosing one; putting its handler
        # first keeps the inner block closed before the outer one.
        insertions[body_end] = patch + insertions.get(body_end, [])

    if not insertions:
        return code

    patched = []
    for position, line in enumerate(lines):
        patched.extend(insertions.get(position, ()))
        patched.append(line)
    patched.extend(insertions.get(len(lines), ()))
    return '\n'.join(patched)
158
+
159
+ import re
160
+
161
def remove_extra_variables_from_function(code):
    """
    Drop unused parameters from the first function definition in `code`.

    A parameter is kept when its name occurs as an identifier anywhere in the
    text after the matched `def ...(...):` header. Default values and
    annotations of kept parameters are preserved. Only the first matched
    definition is rewritten; any other definitions in the text are left
    exactly as they are.

    Args:
        code: Generated source text.

    Returns:
        The text with the first function's parameter list filtered, or the
        text unchanged when no `def name(...):` header is found.
    """
    match = re.search(r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\):', code)
    if match is None:
        return code

    func_name = match.group(1)
    params = [param.strip() for param in match.group(2).split(',')]

    # Everything after the header. Slicing at the match end (instead of at the
    # first ':' in the text) keeps the header's own parameter names out of the
    # "used" set when a colon appears earlier in the text.
    body = code[match.end():]
    used_names = set(re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', body))

    kept = []
    for param in params:
        # Compare the bare name, so `x=1`, `x: int` and `*args` are judged by
        # `x` / `args` rather than by the full parameter text.
        name = re.match(r'\*{0,2}\s*([a-zA-Z_][a-zA-Z0-9_]*)', param)
        if name and name.group(1) in used_names:
            kept.append(param)

    new_func_def = f"def {func_name}({', '.join(kept)}):"

    # Splice in only the matched header; other `def` lines are untouched.
    return code[:match.start()] + new_func_def + code[match.end():]
191
+
192
+ # Post-process the generated code
193
def post_process_code(code):
    """
    Run the generated code through every clean-up pass, in a fixed order.

    Each pass receives the output of the previous one; the result of the last
    pass is returned.
    """
    cleanup_passes = (
        fix_unmatched_brackets,
        validate_bracket_balance,
        add_missing_imports,
        remove_evalf,
        handle_sum_errors,
        complete_try_catch_block,
        remove_extra_variables_from_function,
    )
    for cleanup in cleanup_passes:
        code = cleanup(code)
    return code
202
+
203
+ # Generate the final code from LaTeX
204
def generate_code(latex_expression, max_length=512):
    """
    Generate post-processed Python code for a LaTeX expression.

    The expression is wrapped in the prompt
    "Latex Expression: <expression> Solution:", tokenized, passed to
    `model.generate`, decoded with special tokens skipped, and finally run
    through `post_process_code`.

    Args:
        latex_expression: LaTeX source to translate.
        max_length: Forwarded to `model.generate` as `max_length`.

    Returns:
        The decoded model output after post-processing.
    """
    prompt = f"Latex Expression: {latex_expression} Solution:"

    # Move the encoded prompt to the device the model is on instead of a
    # hard-coded "cuda": the model is loaded without an explicit device move,
    # so forcing CUDA fails on CPU-only hosts and otherwise leaves inputs and
    # weights on different devices.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_length=max_length)
    generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return post_process_code(generated_code)
210
+
211
+
212
+
213
 
214
  # Streamlit app layout
215
  st.title("LaTeX to Python Code Generator")