Spaces:
Runtime error
Runtime error
File size: 8,976 Bytes
c7260b4 8b11e04 5a5e927 c7260b4 8b11e04 4f37b5c c7260b4 4f37b5c c7260b4 4f37b5c c7260b4 4f37b5c c7260b4 4f37b5c e6bfaba 4f37b5c c7260b4 26c35dd c7260b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer
import re
import ast
# Load the fine-tuned model and tokenizer
model_repo_path = 'sabssag/Latex_to_Python_CodeT5-base'


@st.cache_resource
def _load_model_and_tokenizer(repo_path):
    """
    Load the fine-tuned model (switched to eval mode) and its tokenizer.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single loaded copy per server process
    instead of reloading the weights on each rerun.

    The weights are loaded in float32 because the model is never moved off
    the CPU here, and float16 on CPU is slow and not supported by every CPU
    kernel in older PyTorch releases (T5 checkpoints are also prone to
    overflow in float16).

    Args:
        repo_path: Hugging Face Hub repository id to load from.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = T5ForConditionalGeneration.from_pretrained(repo_path, torch_dtype=torch.float32)
    loaded_model.eval()
    loaded_tokenizer = RobertaTokenizer.from_pretrained(repo_path)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_repo_path)
# Fix unmatched brackets
def fix_unmatched_brackets(code):
    """
    Repair bracket nesting so every (), [] and {} in ``code`` is matched.

    The text is scanned left to right with a stack of open brackets:

    * a closing bracket whose opener is somewhere on the stack first closes
      any inner brackets still open above it (``"f(a[1)"`` -> ``"f(a[1])"``);
    * a closing bracket with no matching opener on the stack is dropped;
    * openers still pending at the end are closed innermost-first. Before
      each appended closer a single trailing newline, if present, is removed
      so the closers land on the last line of code.

    Brackets inside string literals or comments are treated like any other
    bracket character.

    Args:
        code: Source text to repair (may be syntactically broken).

    Returns:
        The repaired source text.
    """
    closer_for = {'(': ')', '[': ']', '{': '}'}
    opener_for = {close: open_ for open_, close in closer_for.items()}
    stack = []  # currently open brackets, innermost last
    out = []
    for char in code:
        if char in closer_for:
            stack.append(char)
        elif char in opener_for:
            wanted = opener_for[char]
            if wanted not in stack:
                # Stray closer with nothing to close: drop it.
                continue
            # Close inner brackets that were left open before this closer,
            # so the closer then pairs with its real opener.
            while stack[-1] != wanted:
                out.append(closer_for[stack.pop()])
            stack.pop()
        out.append(char)
    # Close everything still open at the end of the code.
    while stack:
        if out and out[-1] == '\n':
            out.pop()
        out.append(closer_for[stack.pop()])
    return ''.join(out)
# Validate and correct bracket balance
def validate_bracket_balance(code):
    """
    Balance the brackets in ``code`` and return the corrected text.

    A closing bracket that does not match the innermost open bracket is
    removed, and scanning continues so brackets after it are still counted.
    Brackets left open at the end are closed by appending their closers to
    the end of the text, innermost first.

    Args:
        code: Source text to check.

    Returns:
        Text in which every (), [] and {} is balanced.
    """
    bracket_map = {')': '(', ']': '[', '}': '{'}
    closer_for = {'(': ')', '[': ']', '{': '}'}
    stack = []
    kept = []
    for char in code:
        if char in closer_for:
            stack.append(char)
        elif char in bracket_map:
            if not stack or stack[-1] != bracket_map[char]:
                # Misaligned closer: drop it. Replacing it with '#' would
                # comment out the rest of the line, including any closers
                # appended below when it is the last line.
                continue
            stack.pop()
        kept.append(char)
    kept.extend(closer_for[opener] for opener in reversed(stack))
    return ''.join(kept)
# Add missing imports based on used functions
def add_missing_imports(code):
    """
    Add the sympy/numpy imports that the generated code needs.

    * Identifier tokens in ``code`` that belong to a fixed set of sympy names
      and are not already imported via ``from sympy import ...`` are added.
      When anything is added, every whole-line ``from sympy import a, b``
      statement is removed and replaced by one sorted import at the top.
    * Malformed lines such as ``from sympy import, pi`` are removed (only
      that line).
    * ``import numpy as np`` is prepended when ``np.`` appears without it.

    Args:
        code: Generated Python source.

    Returns:
        The source with imports fixed up.
    """
    sympy_funcs = {
        "cot", "sqrt", "pi", "sin", "cos", "tan", "log", "Abs", "exp",
        "factorial", "csc", "sec", "asin", "acos", "atan", "Eq", "symbols", "Function", "Derivative"
    }
    # Every identifier-like token (this also picks up names in strings/comments).
    used_functions = set(re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', code))
    # Names from existing imports; digits are included so e.g. 'atan2' is not
    # truncated to 'atan'.
    existing_imports = re.findall(r'from sympy import ([A-Za-z0-9_, ]+)', code)
    # Skip empty entries from stray commas ("from sympy import pi, "), which
    # would otherwise produce the invalid "from sympy import , pi".
    existing_imports_set = {
        name.strip()
        for group in existing_imports
        for name in group.split(',')
        if name.strip()
    }
    required_imports = used_functions.intersection(sympy_funcs) - existing_imports_set
    if required_imports:
        names = ', '.join(sorted(existing_imports_set | required_imports))
        import_statement = f"from sympy import {names}\n"
        # Remove whole import lines (indentation included, and also on the
        # last line without a trailing newline) before adding the merged one.
        code = re.sub(r'(?m)^[ \t]*from sympy import [A-Za-z0-9_, ]+[ \t]*$\n?', '', code)
        code = import_statement + code
    # Remove malformed imports like `from sympy import, pi`; [^\n]* keeps the
    # match on that single line.
    code = re.sub(r'(?m)^[ \t]*from sympy import,[^\n]*\n?', '', code)
    if "np." in code and "import numpy as np" not in code:
        code = "import numpy as np\n" + code
    return code
# Enhanced removal of evalf() calls, handling malformed cases
def remove_evalf(code):
    """
    Strip every argument-less ``evalf()`` call from ``code``.

    The method form ``.evalf()`` is deleted first, then the malformed
    ``*evalf()`` form; the result is passed through
    ``fix_unmatched_brackets`` to rebalance brackets.
    """
    for evalf_form in (r'\.evalf\(\)', r'\*evalf\(\)'):
        code = re.sub(evalf_form, '', code)
    return fix_unmatched_brackets(code)
def handle_sum_errors(code):
    """
    Unwrap built-in ``sum(...)`` calls whose argument is an arithmetic expression.

    Code like ``sum(x**2 + 1)`` applies ``sum`` to a single expression, which
    is not iterable. Such calls are replaced by the parenthesised argument,
    ``(x**2 + 1)``, so operator precedence in the surrounding expression is
    preserved.

    Calls are left untouched when they may be valid:

    * attribute or suffixed calls such as ``np.sum(...)`` or ``cumsum(...)``;
    * arguments with no ``+ - * /`` operator (e.g. ``sum(values)``);
    * arguments containing brackets, braces or commas (list/set literals,
      subscripts, a ``start`` argument) or a ``for`` generator clause;
    * arguments containing nested parentheses (never matched).

    Args:
        code: Generated Python source.

    Returns:
        The source with those ``sum`` calls unwrapped.
    """
    # Negative lookbehind: only a bare `sum(`, not `np.sum(` / `cumsum(` / `fsum(`.
    invalid_sum_pattern = r'(?<![\w.])sum\(([^()]+)\)'

    def _unwrap(match):
        arg = match.group(1)
        looks_iterable = re.search(r'[\[\]{},]|\bfor\b', arg)
        is_arithmetic = re.search(r'[-+*/]', arg)
        if looks_iterable or not is_arithmetic:
            return match.group(0)
        return f'({arg})'

    return re.sub(invalid_sum_pattern, _unwrap, code)
def complete_try_catch_block(code):
    """
    Give every ``try:`` block that lacks a handler a generic ``except`` clause.

    For each line consisting of ``try:`` (optionally followed by a comment),
    the try body is the run of following lines indented deeper than the
    ``try``. If the first non-blank line after the body is not an ``except``
    or ``finally`` clause at the ``try``'s indentation, this is inserted
    directly after the body::

        except Exception as e:
            print(f"Error: {e}")

    It is preceded by ``pass`` when the body is empty. Blocks that already
    have a handler (including ``try``/``finally``) are left alone, and
    single-line ``try: stmt`` forms are not touched.

    Args:
        code: Generated Python source.

    Returns:
        The source with a handler added to each incomplete ``try`` block.
    """
    if 'try:' not in code:
        return code

    def _indent_width(text):
        return len(text) - len(text.lstrip())

    lines = code.split('\n')
    try_header = re.compile(r'([ \t]*)try:[ \t]*(#.*)?$')
    insertions = {}  # line index -> handler lines to insert before that line

    for idx, line in enumerate(lines):
        header = try_header.match(line)
        if header is None:
            continue
        indent = header.group(1)
        last_body = idx
        body_indent = None
        follower = None
        for k in range(idx + 1, len(lines)):
            if not lines[k].strip():
                continue  # blank lines neither end nor extend the body
            if _indent_width(lines[k]) <= len(indent):
                follower = lines[k]
                break
            if body_indent is None:
                body_indent = lines[k][:_indent_width(lines[k])]
            last_body = k
        if (follower is not None
                and _indent_width(follower) == len(indent)
                and re.match(r'(except|finally)\b', follower.lstrip())):
            continue  # this try already has a handler
        if body_indent is None:
            body_indent = indent + '    '
        handler = [indent + 'except Exception as e:', body_indent + 'print(f"Error: {e}")']
        if last_body == idx:
            handler.insert(0, body_indent + 'pass')
        # A nested (later) try ending on the same line as its enclosing try
        # must get its handler first, so prepend.
        insertions[last_body + 1] = handler + insertions.get(last_body + 1, [])

    result = []
    for idx, line in enumerate(lines):
        result.extend(insertions.get(idx, []))
        result.append(line)
    result.extend(insertions.get(len(lines), []))
    return '\n'.join(result)
import re
def remove_extra_variables_from_function(code):
    """
    Drop unused parameters from the first single-line ``def`` header in ``code``.

    A parameter is kept when its name (ignoring any annotation or default)
    appears as an identifier anywhere after the header, in the body or in
    later code such as call sites. ``self``, ``cls``, ``*args`` and
    ``**kwargs`` are always kept. Headers containing a bare ``*`` or ``/``
    marker, or a parameter that cannot be parsed, are left unchanged. Only
    the matched header is rewritten, as ``def name(kept, params):``.

    Args:
        code: Generated Python source.

    Returns:
        The source with the first function header pruned.
    """
    match = re.search(r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\):', code)
    if not match:
        return code
    func_name = match.group(1)

    # Split on top-level commas only, so defaults like `b=f(x, y)` stay whole.
    params, depth, current = [], 0, []
    for ch in match.group(2):
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
        if ch == ',' and depth == 0:
            params.append(''.join(current).strip())
            current = []
        else:
            current.append(ch)
    params.append(''.join(current).strip())
    params = [p for p in params if p]

    names = []
    for param in params:
        parsed = re.fullmatch(r'\*{0,2}([A-Za-z_]\w*)\s*(?:[:=].*)?', param)
        if parsed is None:
            # Bare `*` / `/` markers or unparseable text: don't risk corrupting it.
            return code
        names.append(parsed.group(1))

    # Identifiers after the header only, so a parameter does not count as
    # used merely because it is declared.
    used_vars = set(re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', code[match.end():]))
    kept = [
        param
        for param, name in zip(params, names)
        if param.startswith('*') or name in ('self', 'cls') or name in used_vars
    ]
    new_func_def = f"def {func_name}({', '.join(kept)}):"
    # Replace only this header; other definitions in the code stay intact.
    return code[:match.start()] + new_func_def + code[match.end():]
# Post-process the generated code
def post_process_code(code):
    """
    Run the generated code through every clean-up pass, in a fixed order.

    Order: bracket repair, bracket validation, import fix-up, evalf()
    removal, sum() unwrapping, try/except completion, and pruning of unused
    function parameters. Each pass receives the previous pass's output.
    """
    passes = (
        fix_unmatched_brackets,
        validate_bracket_balance,
        add_missing_imports,
        remove_evalf,
        handle_sum_errors,
        complete_try_catch_block,
        remove_extra_variables_from_function,
    )
    for clean_up in passes:
        code = clean_up(code)
    return code
# Generate the final code from LaTeX
def generate_code(latex_expression, max_length=512):
    """
    Translate a LaTeX expression into post-processed Python source.

    Args:
        latex_expression: LaTeX text embedded in the model prompt.
        max_length: Forwarded as ``max_length`` to ``model.generate``.

    Returns:
        The first decoded generation, passed through ``post_process_code``.
    """
    prompt = f"Latex Expression: {latex_expression} Solution:"
    encoded = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**encoded, max_length=max_length)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return post_process_code(decoded)
# Streamlit app layout
st.title("LaTeX to Python Code Generator")
# Keep the last submitted expression across Streamlit reruns.
if 'latex_expr' not in st.session_state:
    st.session_state.latex_expr = ""
# User input for LaTeX expression
latex_input = st.text_area("Enter the LaTeX Expression", value=st.session_state.latex_expr, height=150)
if st.button("Generate Code"):
    # Whitespace-only input counts as empty, so the model is not invoked on it.
    if latex_input.strip():
        st.session_state.latex_expr = latex_input
        with st.spinner("Generating Python Code..."):
            try:
                generated_code = generate_code(latex_expression=st.session_state.latex_expr)
                st.subheader("Generated Python Code")
                st.code(generated_code, language='python')
            except Exception as e:
                # Top-level UI boundary: report the failure instead of crashing the page.
                st.error(f"Error during code generation: {e}")
    else:
        st.warning("Please enter a LaTeX expression to generate Python code.")