AhmedSSoliman commited on
Commit
8fd005b
1 Parent(s): 2ba0142

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ from prettytable import PrettyTable
7
+ import streamlit as st
8
+
9
+ #st.title('Code Generation on the CoNaLa Dataset')
10
+
11
+ import subprocess
12
+ import re
13
+ import pandas as pd
14
+ import plotly.express as px
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
16
+ from prettytable import PrettyTable
17
+
18
+ #browser.gatherUsageStats=False
19
+
20
+ class CodeGenerator:
21
+ def __init__(self):
22
+ self.tokenizer = AutoTokenizer.from_pretrained("AhmedSSoliman/MarianCG-CoNaLa-Large")
23
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("AhmedSSoliman/MarianCG-CoNaLa-Large")
24
+
25
+ def generate_code(self, nl_input):
26
+ input_ids = self.tokenizer.encode(nl_input, return_tensors="pt")
27
+ output_ids = self.model.generate(input_ids)
28
+ output_code = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
29
+ return output_code
30
+
31
+
32
+
33
+
34
+
35
+ def check_code(self, code):
36
+ with open("temp.py", "w") as f:
37
+ f.write(code)
38
+ result = subprocess.run(["flake8", "--count", "temp.py"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
39
+ output = result.stdout.decode()
40
+ error = result.stderr.decode()
41
+
42
+
43
+ return output
44
+ #return self._process_output(output, error)
45
+
46
+ def check_code_list(self, code_list):
47
+ output = ""
48
+ error = ""
49
+ for code in code_list:
50
+ with open("temp.py", "w") as f:
51
+ f.write(code)
52
+ result = subprocess.run(["flake8", "--count", "temp.py"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
53
+ output += result.stdout.decode()
54
+ error += result.stderr.decode()
55
+
56
+ return self._process_output(output, error)
57
+
58
+ def _process_output(self, output, error):
59
+ if output:
60
+ output_counts = self._get_error_counts(output)
61
+ self.show_variables_in_table(output_counts, output)
62
+ self.visualize_all_errors(output_counts)
63
+ self.visualize_error_types(output_counts)
64
+
65
+ return self._format_error_counts(output_counts)
66
+ else:
67
+ error_counts = self._get_error_counts(error)
68
+ self.show_variables_in_table(output_counts, output)
69
+ self.visualize_all_errors(error_counts)
70
+ self.visualize_error_types(error_counts)
71
+
72
+ return self._format_error_counts(error_counts)
73
+
74
+ def _get_error_counts(self, output):
75
+ error_counts = {}
76
+ error_messages = re.findall(r"temp.py:(\d+):\d+: (\w\d+)", output)
77
+ for message in error_messages:
78
+ error_type = message[1]
79
+ if error_type in error_counts:
80
+ error_counts[error_type] += 1
81
+ else:
82
+ error_counts[error_type] = 1
83
+ return error_counts
84
+
85
+ def _format_error_counts(self, error_counts):
86
+ error_message = "\n".join([f"{error_type}: {count}" for error_type, count in error_counts.items()])
87
+ return error_message
88
+
89
+ def visualize_all_errors(self, error_counts):
90
+ for error_type, count in error_counts.items():
91
+ print(f"{error_type}: {count}\n")
92
+
93
+
94
+ def visualize_error_types(self, error_counts):
95
+ df = pd.DataFrame({'Error Type': list(error_counts.keys()), 'Count': list(error_counts.values())})
96
+ fig = px.bar(df, x='Count', y='Error Type', title='Error Occurrences in The Generated Code')
97
+ fig.update_layout(
98
+ title={
99
+ 'text': "Error Occurrences in The Generated Code",
100
+ 'x': 0.5,
101
+ 'y': 0.96,
102
+ 'xanchor': 'center',
103
+ 'yanchor': 'top'
104
+ },
105
+ xaxis_title="Error Counts",
106
+ yaxis_title="Error Codes"
107
+ )
108
+ fig.show()
109
+
110
+ def show_variables_in_table(self, output_counts, output):
111
+ table = PrettyTable()
112
+ table.field_names = ["Error Code", "Message"]
113
+ table.add_row([output_counts, output])
114
+ #table.add_row(["Error", error])
115
+ print(table)
116
+
117
+ def display_variables(self, output, error):
118
+ output_df = pd.DataFrame({"Output": [output]})
119
+ error_df = pd.DataFrame({"Error": [error]})
120
+ display(pd.concat([output_df, error_df], axis=1))
121
+
122
+
123
+
124
+
125
+
126
+
127
+ import autopep8
128
+ import black
129
+ import isort
130
+ import pylint.lint
131
+ import autoimport
132
+ from yapf.yapflib.yapf_api import FormatCode # reformat a string of code
133
+
134
+ class PythonCodeFormatter:
135
+ def __init__(self, code):
136
+ self.code = code.replace('▁', ' ').strip()
137
+
138
+
139
+ def load_code_from_file(self, filename):
140
+ # Load the code to be fixed
141
+ with open(filename, 'r') as f:
142
+ self.code = f.read()
143
+
144
+ def format(self):
145
+ try:
146
+ # Use isort to sort and organize the imports
147
+ formatted_code = isort.code(self.code)
148
+
149
+ # Use black to format the code
150
+ formatted_code = black.format_str(formatted_code, mode=black.Mode())
151
+
152
+ # Use autoimport to add a missing import statement
153
+ formatted_code = autoimport.fix_code(formatted_code)
154
+
155
+ # Use autopep8 to fix any remaining issues
156
+ formatted_code = autopep8.fix_code(formatted_code)
157
+
158
+ formatted_code, changed = FormatCode(formatted_code)
159
+
160
+ return formatted_code
161
+
162
+ except RuntimeError as error:
163
+ if str(error) == 'Project root not found.':
164
+ return formatted_code
165
+ else:
166
+ raise # re-raise the error if it's not the one we're looking for
167
+
168
+ except ValueError as error:
169
+ return formatted_code
170
+
171
+ return formatted_code
172
+
173
+
174
+ def save(self, filename):
175
+ # Save the fixed code to a file
176
+ with open(filename, 'w') as f:
177
+ f.write(self.code)
178
+
179
+
180
+
181
+
182
+
183
+ code_generator = CodeGenerator()
184
+
185
+
186
+ # Streamlit app
187
+ def main():
188
+ st.title('Code Generator and Error Checker')
189
+ nl_input = st.text_area('Enter natural language input for code generation')
190
+ if st.button('Generate Code'):
191
+ # Generate code
192
+ output_code = code_generator.generate_code(nl_input)
193
+ st.subheader('Generated Code')
194
+ st.code(output_code, language='python')
195
+
196
+ # Check code for errors
197
+ st.subheader('Error Check')
198
+ error_message = code_generator.check_code(output_code)
199
+ st.write('Error Counts:')
200
+ st.write(error_message)
201
+
202
+
203
+ st.subheader('Error Correction')
204
+ formatter = PythonCodeFormatter(output_code)
205
+ formatted_code = formatter.format()
206
+ st.write('Code after correction:')
207
+ st.write(formatted_code)
208
+ #st.subheader('Code after correction:')
209
+ #st.code(formatted_code, language='python')
210
+
211
+
212
+
213
+ if __name__ == '__main__':
214
+
215
+ main()