AhmedSSoliman committed on
Commit
3671d71
1 Parent(s): cd090e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ from prettytable import PrettyTable
7
+ import streamlit as st
8
+
9
+ st.title('Code Generation on the CoNaLa Dataset')
10
+
11
class CodeGenerator:
    """Generate Python code from natural language and lint it with flake8.

    Creating an instance loads the tokenizer and the seq2seq model from the
    ``AhmedSSoliman/MarianCG-CoNaLa-Large`` checkpoint. Generated snippets
    are written to ``temp.py`` in the current working directory and checked
    by running the ``flake8`` executable as a subprocess.
    """

    def __init__(self):
        """Load the pretrained tokenizer and model."""
        self.tokenizer = AutoTokenizer.from_pretrained("AhmedSSoliman/MarianCG-CoNaLa-Large")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("AhmedSSoliman/MarianCG-CoNaLa-Large")

    def generate_code(self, nl_input):
        """Return the code string the model generates for ``nl_input``.

        Args:
            nl_input: Natural-language description passed to the tokenizer.

        Returns:
            The first generated sequence, decoded with special tokens
            skipped.
        """
        input_ids = self.tokenizer.encode(nl_input, return_tensors="pt")
        output_ids = self.model.generate(input_ids)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def _run_flake8(self, code, *flake8_args):
        """Write ``code`` to ``temp.py``, run flake8 on it, return (stdout, stderr)."""
        # Explicit UTF-8 so non-ASCII generated code is written the same way
        # on every platform instead of depending on the locale default.
        with open("temp.py", "w", encoding="utf-8") as f:
            f.write(code)
        result = subprocess.run(
            ["flake8", *flake8_args, "temp.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return result.stdout.decode(), result.stderr.decode()

    def check_code(self, code):
        """Lint one snippet and return the formatted per-code error counts.

        Args:
            code: Source text to check.

        Returns:
            The string produced by ``_process_output`` for the flake8 run.
        """
        output, error = self._run_flake8(code)
        return self._process_output(output, error)

    def check_code_list(self, code_list):
        """Lint several snippets and return the combined error counts.

        Each snippet is written to ``temp.py`` and checked with
        ``flake8 --count``; the stdout and stderr of all runs are
        concatenated before being processed.

        Args:
            code_list: Iterable of source strings.

        Returns:
            The string produced by ``_process_output`` for the combined
            output.
        """
        outputs = []
        errors = []
        for code in code_list:
            output, error = self._run_flake8(code, "--count")
            outputs.append(output)
            errors.append(error)
        return self._process_output("".join(outputs), "".join(errors))

    def _process_output(self, output, error):
        """Count, display and format the errors found in flake8's output.

        ``output`` (stdout) is used when it is non-empty; otherwise
        ``error`` (stderr) is used. The previous implementation referenced
        an unbound ``output_counts`` in the fallback branch, so any run with
        empty stdout raised ``UnboundLocalError``.

        Returns:
            One ``"<code>: <count>"`` line per error code, or an empty
            string when no codes were found.
        """
        report = output if output else error
        counts = self._get_error_counts(report)
        self.show_variables_in_table(counts, report)
        self.visualize_all_errors(counts)
        self.visualize_error_types(counts)
        return self._format_error_counts(counts)

    def _get_error_counts(self, output):
        """Return ``{error_code: occurrences}`` parsed from flake8 text.

        Only lines of the form ``temp.py:<line>:<col>: <code>`` are counted,
        where ``<code>`` is one word character followed by digits.
        """
        error_counts = {}
        # The dot is escaped so only the literal file name "temp.py" matches.
        for error_type in re.findall(r"temp\.py:\d+:\d+: (\w\d+)", output):
            error_counts[error_type] = error_counts.get(error_type, 0) + 1
        return error_counts

    def _format_error_counts(self, error_counts):
        """Return one ``"<code>: <count>"`` line per entry, newline-joined."""
        return "\n".join(f"{error_type}: {count}" for error_type, count in error_counts.items())

    def visualize_all_errors(self, error_counts):
        """Print every error code and its count to stdout."""
        for error_type, count in error_counts.items():
            print(f"{error_type}: {count}\n")

    def visualize_error_types(self, error_counts):
        """Show a Plotly bar chart with counts on x and error codes on y."""
        df = pd.DataFrame({'Error Type': list(error_counts.keys()), 'Count': list(error_counts.values())})
        fig = px.bar(df, x='Count', y='Error Type', title='Error Occurrences in The Generated Code')
        fig.update_layout(
            title={
                'text': "Error Occurrences in The Generated Code",
                'x': 0.5,
                'y': 0.96,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            xaxis_title="Error Counts",
            yaxis_title="Error Codes"
        )
        fig.show()

    def show_variables_in_table(self, output_counts, output):
        """Print a one-row PrettyTable holding the counts and the raw text."""
        table = PrettyTable()
        table.field_names = ["Error Code", "Message"]
        table.add_row([output_counts, output])
        print(table)

    def display_variables(self, output, error):
        """Show ``output`` and ``error`` side by side as a one-row DataFrame.

        Uses ``display`` when that name is available (presumably injected by
        IPython/Jupyter - confirm for the target environment) and falls back
        to ``print`` otherwise, instead of raising ``NameError`` in a plain
        Python process.
        """
        output_df = pd.DataFrame({"Output": [output]})
        error_df = pd.DataFrame({"Error": [error]})
        combined = pd.concat([output_df, error_df], axis=1)
        try:
            render = display  # noqa: F821 - only defined in some environments
        except NameError:
            render = print
        render(combined)
110
+
111
+
112
+
113
# Module-level instance used by main(). Constructing it runs
# CodeGenerator.__init__, which loads the tokenizer and the model.
code_generator = CodeGenerator()
114
+
115
+
116
+ # Streamlit app
117
def main():
    """Render the page: prompt for a description, generate code, lint it.

    Nothing is generated until the 'Generate Code' button is pressed. A
    blank description is rejected with a warning instead of being sent to
    the model and to flake8.
    """
    st.title('Code Generator and Error Checker')
    nl_input = st.text_area('Enter natural language input for code generation')
    if not st.button('Generate Code'):
        return

    if not nl_input or not nl_input.strip():
        st.warning('Please enter a description before generating code.')
        return

    # Generate code
    output_code = code_generator.generate_code(nl_input)
    st.subheader('Generated Code')
    st.code(output_code, language='python')

    # Check code for errors
    st.subheader('Error Check')
    error_message = code_generator.check_code(output_code)
    st.write('Error Counts:')
    # check_code returns an empty string when no error codes were parsed;
    # say so explicitly rather than rendering an empty element.
    st.write(error_message or 'No flake8 error codes were reported.')
131
+
132
+
133
# Start the app only when this file is executed as the main script.
if __name__ == '__main__':

    main()
136
+