Update app.py
Browse files
app.py
CHANGED
@@ -1,277 +1,280 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from huggingface_hub import InferenceClient
|
3 |
-
import pandas as pd
|
4 |
-
import io
|
5 |
-
import mysql.connector
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
- **
|
20 |
-
- **
|
21 |
-
|
22 |
-
**
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
"
|
42 |
-
prompt =
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
**Example
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
cursor.
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
connection
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
cursor.
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
return
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
"
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
#
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
{
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
```
|
179 |
-
""
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
"
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
"""
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
#
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
"
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
return
|
250 |
-
|
251 |
-
elif "
|
252 |
-
return
|
253 |
-
|
254 |
-
elif "
|
255 |
-
|
256 |
-
|
257 |
-
elif "
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
gr.Textbox(label="
|
269 |
-
gr.Textbox(label="
|
270 |
-
gr.
|
271 |
-
|
272 |
-
|
273 |
-
)
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
|
|
|
|
277 |
# print(create_table("create a SQL student table","root", "123456", "localhost", "demo"))
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
import pandas as pd
|
4 |
+
import io
|
5 |
+
import mysql.connector
|
6 |
+
import os
|
7 |
+
|
8 |
+
HF_API_KEY = os.getenv("HF_API_KEY")
|
9 |
+
|
10 |
+
# Hugging Face API Key (Replace with your actual key)
|
11 |
+
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)
|
12 |
+
|
13 |
+
def classify_task(user_input):
|
14 |
+
"""Classifies user task using Mistral."""
|
15 |
+
prompt = f"""
|
16 |
+
You are an AI assistant that classifies user requests related to SQL and databases.
|
17 |
+
Your task is to categorize the input into one of the following options:
|
18 |
+
|
19 |
+
- **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
|
20 |
+
- **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.
|
21 |
+
- **generate_demo_data_db** → If the user wants to insert demo data into a database.
|
22 |
+
- **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
|
23 |
+
- **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.
|
24 |
+
|
25 |
+
**Examples:**
|
26 |
+
1. "Give me SQL syntax to create a student table" → **generate_sql**
|
27 |
+
2. "Create a student table in my database" → **create_table**
|
28 |
+
3. "Insert some demo data in my database" → **generate_demo_data_db**
|
29 |
+
4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
|
30 |
+
5. "Analyze student marks and trends" → **analyze_data**
|
31 |
+
|
32 |
+
**User Input:** {user_input}
|
33 |
+
|
34 |
+
**Output Format:** Return only the category name without any explanations.
|
35 |
+
"""
|
36 |
+
|
37 |
+
response = client.text_generation(prompt, max_new_tokens=20).strip()
|
38 |
+
return response
|
39 |
+
def generate_sql_query(user_input):
|
40 |
+
"""Generates SQL queries using Mistral."""
|
41 |
+
prompt = f"Generate SQL syntax for: {user_input}"
|
42 |
+
return client.text_generation(prompt, max_new_tokens=200).strip()
|
43 |
+
def generate_sql_query_for_create(user_input):
|
44 |
+
"""Generates SQL queries using Mistral."""
|
45 |
+
prompt = f"""
|
46 |
+
Generate **only** the SQL syntax for: {user_input}
|
47 |
+
|
48 |
+
**Rules:**
|
49 |
+
- No explanations, no bullet points, no extra text.
|
50 |
+
- Return **only valid SQL**.
|
51 |
+
|
52 |
+
**Example Input:**
|
53 |
+
"Create a student table with name, age, and email."
|
54 |
+
|
55 |
+
**Example Output:**
|
56 |
+
```sql
|
57 |
+
CREATE TABLE student (
|
58 |
+
student_id INT PRIMARY KEY AUTO_INCREMENT,
|
59 |
+
name VARCHAR(100) NOT NULL,
|
60 |
+
age INT NOT NULL,
|
61 |
+
email VARCHAR(255) UNIQUE NOT NULL
|
62 |
+
);
|
63 |
+
```
|
64 |
+
"""
|
65 |
+
|
66 |
+
response = client.text_generation(prompt, max_new_tokens=200).strip()
|
67 |
+
|
68 |
+
# Remove unnecessary text (if any)
|
69 |
+
if "```sql" in response:
|
70 |
+
response = response.split("```sql")[1].split("```")[0].strip()
|
71 |
+
|
72 |
+
return response
|
73 |
+
|
74 |
+
import pymysql
|
75 |
+
|
76 |
+
def create_table(user_input, db_user, db_pass, db_host, db_name):
|
77 |
+
try:
|
78 |
+
# Validate inputs
|
79 |
+
if not all([db_user, db_pass, db_host, db_name, user_input]):
|
80 |
+
return "Please provide all required inputs (database credentials and table structure).", None
|
81 |
+
|
82 |
+
# Generate SQL schema using Mistral
|
83 |
+
|
84 |
+
schema_response = generate_sql_query_for_create(user_input)
|
85 |
+
print(schema_response)
|
86 |
+
# Validate schema using sqlparse
|
87 |
+
parsed_schema = sqlparse.parse(schema_response)
|
88 |
+
if not parsed_schema:
|
89 |
+
return "Error: Could not generate a valid table schema.", None
|
90 |
+
|
91 |
+
# Connect to MySQL Server
|
92 |
+
connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
|
93 |
+
cursor = connection.cursor()
|
94 |
+
|
95 |
+
# Create Database if it doesn't exist
|
96 |
+
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
|
97 |
+
connection.commit()
|
98 |
+
connection.close()
|
99 |
+
|
100 |
+
# Connect to the specified database
|
101 |
+
connection = pymysql.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
|
102 |
+
cursor = connection.cursor()
|
103 |
+
|
104 |
+
# Execute the generated CREATE TABLE statement
|
105 |
+
cursor.execute(schema_response)
|
106 |
+
connection.commit()
|
107 |
+
|
108 |
+
return "Table created successfully.", None
|
109 |
+
|
110 |
+
except pymysql.MySQLError as err:
|
111 |
+
return f"Error: {err}", None
|
112 |
+
|
113 |
+
finally:
|
114 |
+
if 'connection' in locals() and connection.open:
|
115 |
+
cursor.close()
|
116 |
+
connection.close()
|
117 |
+
|
118 |
+
import mysql.connector
|
119 |
+
import re
|
120 |
+
import sqlparse # Install via: pip install sqlparse
|
121 |
+
|
122 |
+
def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
|
123 |
+
"""Generates and inserts structured demo data into a database using LLM."""
|
124 |
+
|
125 |
+
if not all([db_user, db_pass, db_name]):
|
126 |
+
return "Please provide database credentials.", None
|
127 |
+
|
128 |
+
# Generate column definitions using LLM
|
129 |
+
schema_prompt = f"""
|
130 |
+
Extract column names and types from the following request:
|
131 |
+
|
132 |
+
"{user_input}"
|
133 |
+
|
134 |
+
**Output Format:**
|
135 |
+
- The first column should be an "ID" column (INTEGER, PRIMARY KEY).
|
136 |
+
- Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
|
137 |
+
- Use proper SQL syntax. No explanations.
|
138 |
+
|
139 |
+
Example Output:
|
140 |
+
```
|
141 |
+
CREATE TABLE demo (
|
142 |
+
ID INT PRIMARY KEY,
|
143 |
+
Name VARCHAR(100),
|
144 |
+
Age INT
|
145 |
+
);
|
146 |
+
```
|
147 |
+
"""
|
148 |
+
schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()
|
149 |
+
|
150 |
+
# Validate schema using sqlparse
|
151 |
+
parsed_schema = sqlparse.parse(schema_response)
|
152 |
+
if not parsed_schema:
|
153 |
+
return "Error: Could not generate a valid table schema.", None
|
154 |
+
|
155 |
+
# Extract table schema
|
156 |
+
table_schema = schema_response.replace("CREATE TABLE demo (", "").replace(");", "").strip()
|
157 |
+
|
158 |
+
# Connect to MySQL and create the table dynamically
|
159 |
+
connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
|
160 |
+
cursor = connection.cursor()
|
161 |
+
cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")
|
162 |
+
|
163 |
+
# Generate demo data using LLM
|
164 |
+
data_prompt = f"""
|
165 |
+
Generate {num_rows} rows of structured demo data for this table schema:
|
166 |
+
|
167 |
+
```
|
168 |
+
{schema_response}
|
169 |
+
```
|
170 |
+
|
171 |
+
**Output Format:**
|
172 |
+
- Return valid SQL INSERT statements.
|
173 |
+
- Ensure all values match their respective column types.
|
174 |
+
- Use double quotes ("") for text values.
|
175 |
+
- No explanations, just raw SQL.
|
176 |
+
|
177 |
+
Example Output:
|
178 |
+
```
|
179 |
+
INSERT INTO demo VALUES (1, "John Doe", 25);
|
180 |
+
INSERT INTO demo VALUES (2, "Jane Smith", 30);
|
181 |
+
```
|
182 |
+
"""
|
183 |
+
data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()
|
184 |
+
|
185 |
+
# Extract SQL INSERT statements using a better regex
|
186 |
+
insert_statements = re.findall(r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL)
|
187 |
+
if not insert_statements:
|
188 |
+
return "Error: Could not generate valid data.", None
|
189 |
+
|
190 |
+
# Insert data into the database
|
191 |
+
for statement in insert_statements:
|
192 |
+
cursor.execute(statement)
|
193 |
+
|
194 |
+
connection.commit()
|
195 |
+
connection.close()
|
196 |
+
|
197 |
+
return "Demo data inserted into the database successfully.", None
|
198 |
+
def generate_demo_data_csv(user_input, num_rows=10):
|
199 |
+
"""Generates realistic demo data using the LLM in valid CSV format."""
|
200 |
+
|
201 |
+
prompt = f"""
|
202 |
+
Generate a structured dataset with {num_rows} rows based on the following request:
|
203 |
+
|
204 |
+
"{user_input}"
|
205 |
+
|
206 |
+
**Output Format:**
|
207 |
+
- Ensure the response is in **valid CSV format** (comma-separated).
|
208 |
+
- The **first row** must be column headers.
|
209 |
+
- Use **double quotes for text values** to avoid formatting issues.
|
210 |
+
- Do **not** include explanations—just the raw CSV data.
|
211 |
+
|
212 |
+
Example Output:
|
213 |
+
|
214 |
+
"ID","Name","Age","Email"
|
215 |
+
"1","John Doe","25","john.doe@example.com"
|
216 |
+
"2","Jane Smith","30","jane.smith@example.com"
|
217 |
+
|
218 |
+
"""
|
219 |
+
|
220 |
+
# Get LLM response
|
221 |
+
response = client.text_generation(prompt, max_new_tokens=10000).strip()
|
222 |
+
|
223 |
+
# Ensure we extract only the CSV part (some models may add explanations)
|
224 |
+
csv_start = response.find('"ID"') # Find where the CSV starts
|
225 |
+
if csv_start != -1:
|
226 |
+
response = response[csv_start:] # Remove anything before the CSV
|
227 |
+
|
228 |
+
# Convert to DataFrame
|
229 |
+
try:
|
230 |
+
df = pd.read_csv(io.StringIO(response)) # Read as CSV
|
231 |
+
except Exception as e:
|
232 |
+
return f"Error: Invalid CSV format. {str(e)}", None
|
233 |
+
|
234 |
+
# Save to a CSV file
|
235 |
+
file_path = "generated_data.csv"
|
236 |
+
df.to_csv(file_path, index=False)
|
237 |
+
|
238 |
+
return "Demo data generated as CSV.", file_path # Return file path
|
239 |
+
|
240 |
+
def analyze_data(user_input):
|
241 |
+
"""Analyzes data using Mistral."""
|
242 |
+
prompt = f"Analyze this data: {user_input}"
|
243 |
+
return client.text_generation(prompt, max_new_tokens=200).strip()
|
244 |
+
|
245 |
+
def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
|
246 |
+
task = classify_task(user_input)
|
247 |
+
|
248 |
+
if "generate_sql" in task:
|
249 |
+
return generate_sql_query(user_input), None
|
250 |
+
|
251 |
+
elif "create_table" in task:
|
252 |
+
return create_table(user_input, db_user, db_pass, db_host, db_name)
|
253 |
+
|
254 |
+
elif "generate_demo_data_db" in task:
|
255 |
+
return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
|
256 |
+
|
257 |
+
elif "generate_demo_data_csv" in task:
|
258 |
+
response, file_path = generate_demo_data_csv(user_input, num_rows)
|
259 |
+
return response, file_path
|
260 |
+
elif "analyze_data" in task:
|
261 |
+
return analyze_data(user_input), None
|
262 |
+
|
263 |
+
return f"task:{task} \n I could not understand your request.", None
|
264 |
+
|
265 |
+
iface = gr.Interface(
|
266 |
+
fn=sql_chatbot,
|
267 |
+
inputs=[
|
268 |
+
gr.Textbox(label="User Input"),
|
269 |
+
gr.Textbox(label="MySQL Username", interactive=True),
|
270 |
+
gr.Textbox(label="MySQL Password", interactive=True, type="password"),
|
271 |
+
gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
|
272 |
+
gr.Textbox(label="Database Name", interactive=True),
|
273 |
+
gr.Number(label="Number of Rows", interactive=True, value=10, precision=0)
|
274 |
+
],
|
275 |
+
outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")]
|
276 |
+
)
|
277 |
+
|
278 |
+
iface.launch()
|
279 |
+
# print("hi")
|
280 |
# print(create_table("create a SQL student table","root", "123456", "localhost", "demo"))
|