Garvitj commited on
Commit
c66b450
·
verified ·
1 Parent(s): 3e3098b

Upload new.py

Browse files
Files changed (1) hide show
  1. new.py +277 -0
new.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import pandas as pd
4
+ import io
5
+ import mysql.connector
6
+
7
+ # Hugging Face API Key (Replace with your actual key)
8
+ client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)
9
+
10
+ def classify_task(user_input):
11
+ """Classifies user task using Mistral."""
12
+ prompt = f"""
13
+ You are an AI assistant that classifies user requests related to SQL and databases.
14
+ Your task is to categorize the input into one of the following options:
15
+
16
+ - **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
17
+ - **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.
18
+ - **generate_demo_data_db** → If the user wants to insert demo data into a database.
19
+ - **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
20
+ - **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.
21
+
22
+ **Examples:**
23
+ 1. "Give me SQL syntax to create a student table" → **generate_sql**
24
+ 2. "Create a student table in my database" → **create_table**
25
+ 3. "Insert some demo data in my database" → **generate_demo_data_db**
26
+ 4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
27
+ 5. "Analyze student marks and trends" → **analyze_data**
28
+
29
+ **User Input:** {user_input}
30
+
31
+ **Output Format:** Return only the category name without any explanations.
32
+ """
33
+
34
+ response = client.text_generation(prompt, max_new_tokens=20).strip()
35
+ return response
36
+ def generate_sql_query(user_input):
37
+ """Generates SQL queries using Mistral."""
38
+ prompt = f"Generate SQL syntax for: {user_input}"
39
+ return client.text_generation(prompt, max_new_tokens=200).strip()
40
+ def generate_sql_query_for_create(user_input):
41
+ """Generates SQL queries using Mistral."""
42
+ prompt = f"""
43
+ Generate **only** the SQL syntax for: {user_input}
44
+
45
+ **Rules:**
46
+ - No explanations, no bullet points, no extra text.
47
+ - Return **only valid SQL**.
48
+
49
+ **Example Input:**
50
+ "Create a student table with name, age, and email."
51
+
52
+ **Example Output:**
53
+ ```sql
54
+ CREATE TABLE student (
55
+ student_id INT PRIMARY KEY AUTO_INCREMENT,
56
+ name VARCHAR(100) NOT NULL,
57
+ age INT NOT NULL,
58
+ email VARCHAR(255) UNIQUE NOT NULL
59
+ );
60
+ ```
61
+ """
62
+
63
+ response = client.text_generation(prompt, max_new_tokens=200).strip()
64
+
65
+ # Remove unnecessary text (if any)
66
+ if "```sql" in response:
67
+ response = response.split("```sql")[1].split("```")[0].strip()
68
+
69
+ return response
70
+
71
+ import pymysql
72
+
73
+ def create_table(user_input, db_user, db_pass, db_host, db_name):
74
+ try:
75
+ # Validate inputs
76
+ if not all([db_user, db_pass, db_host, db_name, user_input]):
77
+ return "Please provide all required inputs (database credentials and table structure).", None
78
+
79
+ # Generate SQL schema using Mistral
80
+
81
+ schema_response = generate_sql_query_for_create(user_input)
82
+ print(schema_response)
83
+ # Validate schema using sqlparse
84
+ parsed_schema = sqlparse.parse(schema_response)
85
+ if not parsed_schema:
86
+ return "Error: Could not generate a valid table schema.", None
87
+
88
+ # Connect to MySQL Server
89
+ connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
90
+ cursor = connection.cursor()
91
+
92
+ # Create Database if it doesn't exist
93
+ cursor.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
94
+ connection.commit()
95
+ connection.close()
96
+
97
+ # Connect to the specified database
98
+ connection = pymysql.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
99
+ cursor = connection.cursor()
100
+
101
+ # Execute the generated CREATE TABLE statement
102
+ cursor.execute(schema_response)
103
+ connection.commit()
104
+
105
+ return "Table created successfully.", None
106
+
107
+ except pymysql.MySQLError as err:
108
+ return f"Error: {err}", None
109
+
110
+ finally:
111
+ if 'connection' in locals() and connection.open:
112
+ cursor.close()
113
+ connection.close()
114
+
115
+ import mysql.connector
116
+ import re
117
+ import sqlparse # Install via: pip install sqlparse
118
+
119
+ def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
120
+ """Generates and inserts structured demo data into a database using LLM."""
121
+
122
+ if not all([db_user, db_pass, db_name]):
123
+ return "Please provide database credentials.", None
124
+
125
+ # Generate column definitions using LLM
126
+ schema_prompt = f"""
127
+ Extract column names and types from the following request:
128
+
129
+ "{user_input}"
130
+
131
+ **Output Format:**
132
+ - The first column should be an "ID" column (INTEGER, PRIMARY KEY).
133
+ - Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
134
+ - Use proper SQL syntax. No explanations.
135
+
136
+ Example Output:
137
+ ```
138
+ CREATE TABLE demo (
139
+ ID INT PRIMARY KEY,
140
+ Name VARCHAR(100),
141
+ Age INT
142
+ );
143
+ ```
144
+ """
145
+ schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()
146
+
147
+ # Validate schema using sqlparse
148
+ parsed_schema = sqlparse.parse(schema_response)
149
+ if not parsed_schema:
150
+ return "Error: Could not generate a valid table schema.", None
151
+
152
+ # Extract table schema
153
+ table_schema = schema_response.replace("CREATE TABLE demo (", "").replace(");", "").strip()
154
+
155
+ # Connect to MySQL and create the table dynamically
156
+ connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
157
+ cursor = connection.cursor()
158
+ cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")
159
+
160
+ # Generate demo data using LLM
161
+ data_prompt = f"""
162
+ Generate {num_rows} rows of structured demo data for this table schema:
163
+
164
+ ```
165
+ {schema_response}
166
+ ```
167
+
168
+ **Output Format:**
169
+ - Return valid SQL INSERT statements.
170
+ - Ensure all values match their respective column types.
171
+ - Use double quotes ("") for text values.
172
+ - No explanations, just raw SQL.
173
+
174
+ Example Output:
175
+ ```
176
+ INSERT INTO demo VALUES (1, "John Doe", 25);
177
+ INSERT INTO demo VALUES (2, "Jane Smith", 30);
178
+ ```
179
+ """
180
+ data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()
181
+
182
+ # Extract SQL INSERT statements using a better regex
183
+ insert_statements = re.findall(r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL)
184
+ if not insert_statements:
185
+ return "Error: Could not generate valid data.", None
186
+
187
+ # Insert data into the database
188
+ for statement in insert_statements:
189
+ cursor.execute(statement)
190
+
191
+ connection.commit()
192
+ connection.close()
193
+
194
+ return "Demo data inserted into the database successfully.", None
195
+ def generate_demo_data_csv(user_input, num_rows=10):
196
+ """Generates realistic demo data using the LLM in valid CSV format."""
197
+
198
+ prompt = f"""
199
+ Generate a structured dataset with {num_rows} rows based on the following request:
200
+
201
+ "{user_input}"
202
+
203
+ **Output Format:**
204
+ - Ensure the response is in **valid CSV format** (comma-separated).
205
+ - The **first row** must be column headers.
206
+ - Use **double quotes for text values** to avoid formatting issues.
207
+ - Do **not** include explanations—just the raw CSV data.
208
+
209
+ Example Output:
210
+
211
+ "ID","Name","Age","Email"
212
+ "1","John Doe","25","john.doe@example.com"
213
+ "2","Jane Smith","30","jane.smith@example.com"
214
+
215
+ """
216
+
217
+ # Get LLM response
218
+ response = client.text_generation(prompt, max_new_tokens=10000).strip()
219
+
220
+ # Ensure we extract only the CSV part (some models may add explanations)
221
+ csv_start = response.find('"ID"') # Find where the CSV starts
222
+ if csv_start != -1:
223
+ response = response[csv_start:] # Remove anything before the CSV
224
+
225
+ # Convert to DataFrame
226
+ try:
227
+ df = pd.read_csv(io.StringIO(response)) # Read as CSV
228
+ except Exception as e:
229
+ return f"Error: Invalid CSV format. {str(e)}", None
230
+
231
+ # Save to a CSV file
232
+ file_path = "generated_data.csv"
233
+ df.to_csv(file_path, index=False)
234
+
235
+ return "Demo data generated as CSV.", file_path # Return file path
236
+
237
+ def analyze_data(user_input):
238
+ """Analyzes data using Mistral."""
239
+ prompt = f"Analyze this data: {user_input}"
240
+ return client.text_generation(prompt, max_new_tokens=200).strip()
241
+
242
+ def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
243
+ task = classify_task(user_input)
244
+
245
+ if "generate_sql" in task:
246
+ return generate_sql_query(user_input), None
247
+
248
+ elif "create_table" in task:
249
+ return create_table(user_input, db_user, db_pass, db_host, db_name)
250
+
251
+ elif "generate_demo_data_db" in task:
252
+ return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
253
+
254
+ elif "generate_demo_data_csv" in task:
255
+ response, file_path = generate_demo_data_csv(user_input, num_rows)
256
+ return response, file_path
257
+ elif "analyze_data" in task:
258
+ return analyze_data(user_input), None
259
+
260
+ return f"task:{task} \n I could not understand your request.", None
261
+
262
+ iface = gr.Interface(
263
+ fn=sql_chatbot,
264
+ inputs=[
265
+ gr.Textbox(label="User Input"),
266
+ gr.Textbox(label="MySQL Username", interactive=True),
267
+ gr.Textbox(label="MySQL Password", interactive=True, type="password"),
268
+ gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
269
+ gr.Textbox(label="Database Name", interactive=True),
270
+ gr.Number(label="Number of Rows", interactive=True, value=10, precision=0)
271
+ ],
272
+ outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")]
273
+ )
274
+
275
+ iface.launch()
276
+ # print("hi")
277
+ # print(create_table("create a SQL student table","root", "123456", "localhost", "demo"))