Spaces:
Sleeping
Sleeping
barathm111
commited on
Commit
•
5473610
1
Parent(s):
4ea075c
Upload 5 files
Browse files- .env +4 -0
- .gitignore +1 -0
- Dockerfile +30 -0
- app.py +118 -0
- requirements.txt +7 -0
.env
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DB_HOST=auth-db579.hstgr.io
|
2 |
+
DB_USER=u121769371_ki_aiml_test
|
3 |
+
DB_PASSWORD=Ki_kr_aiml@18!$#$
|
4 |
+
DB_NAME=u121769371_ki_aiml_test
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.env
|
Dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Use the official Python 3.9 image
|
2 |
+
FROM python:3.11
|
3 |
+
|
4 |
+
## Set the working directory to /code
|
5 |
+
WORKDIR /code
|
6 |
+
|
7 |
+
## Copy the requirements.txt file into the container at /code
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
## Install the requirements from requirements.txt
|
11 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
12 |
+
|
13 |
+
## Set up a new user named "user"d
|
14 |
+
RUN useradd user
|
15 |
+
|
16 |
+
## Switch to the "user" user
|
17 |
+
USER user
|
18 |
+
|
19 |
+
## Set home to the user's home directory
|
20 |
+
ENV HOME=/home/user \
|
21 |
+
PATH=/home/user/.local/bin:$PATH
|
22 |
+
|
23 |
+
## Set the working directory to the user's home directory
|
24 |
+
WORKDIR $HOME/app
|
25 |
+
|
26 |
+
## Copy the current directory contents into the container at $HOME/app and set the owner to "user"
|
27 |
+
COPY --chown=user . $HOME/app
|
28 |
+
|
29 |
+
## Start the FASTAPI app on port 7860
|
30 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from transformers import pipeline
|
4 |
+
import mysql.connector
|
5 |
+
import os
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
# Load environment variables from the .env file
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
# Initialize the text generation pipeline
|
14 |
+
pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b", pad_token_id=2)
|
15 |
+
|
16 |
+
class QueryRequest(BaseModel):
|
17 |
+
text: str
|
18 |
+
|
19 |
+
def get_db_connection():
|
20 |
+
"""Create a new database connection."""
|
21 |
+
try:
|
22 |
+
connection = mysql.connector.connect(
|
23 |
+
host=os.getenv("DB_HOST"),
|
24 |
+
user=os.getenv("DB_USER"),
|
25 |
+
password=os.getenv("DB_PASSWORD"),
|
26 |
+
database=os.getenv("DB_NAME"),
|
27 |
+
raise_on_warnings=True
|
28 |
+
)
|
29 |
+
return connection
|
30 |
+
except mysql.connector.Error as err:
|
31 |
+
print(f"Error: {err}")
|
32 |
+
return None
|
33 |
+
|
34 |
+
def get_database_schema():
|
35 |
+
"""Function to retrieve the database schema dynamically."""
|
36 |
+
schema = {}
|
37 |
+
try:
|
38 |
+
conn = get_db_connection()
|
39 |
+
if conn is None:
|
40 |
+
raise Exception("Failed to connect to the database.")
|
41 |
+
|
42 |
+
cursor = conn.cursor()
|
43 |
+
|
44 |
+
# Query to get table names
|
45 |
+
cursor.execute("SHOW TABLES")
|
46 |
+
tables = cursor.fetchall()
|
47 |
+
|
48 |
+
for table in tables:
|
49 |
+
table_name = table[0]
|
50 |
+
cursor.execute(f"DESCRIBE {table_name}")
|
51 |
+
columns = cursor.fetchall()
|
52 |
+
schema[table_name] = [column[0] for column in columns]
|
53 |
+
|
54 |
+
cursor.close()
|
55 |
+
conn.close()
|
56 |
+
except mysql.connector.Error as err:
|
57 |
+
print(f"Error: {err}")
|
58 |
+
return {}
|
59 |
+
except Exception as e:
|
60 |
+
print(f"An error occurred: {e}")
|
61 |
+
return {}
|
62 |
+
|
63 |
+
return schema
|
64 |
+
|
65 |
+
@app.get("/")
|
66 |
+
def home():
|
67 |
+
return {"message": "SQL Generation Server is running"}
|
68 |
+
|
69 |
+
@app.post("/generate")
|
70 |
+
def generate(request: QueryRequest):
|
71 |
+
try:
|
72 |
+
text = request.text
|
73 |
+
|
74 |
+
# Fetch the database schema
|
75 |
+
schema = get_database_schema()
|
76 |
+
schema_str = "\n".join([f"{table}: {', '.join(columns)}" for table, columns in schema.items()])
|
77 |
+
|
78 |
+
# Construct the system message
|
79 |
+
system_message = f"""
|
80 |
+
You are a helpful, cheerful database assistant.
|
81 |
+
Use the following dynamically retrieved database schema when creating your answers:
|
82 |
+
|
83 |
+
{schema_str}
|
84 |
+
|
85 |
+
When creating your answers, consider the following:
|
86 |
+
|
87 |
+
1. If a query involves a column or value that is not present in the provided database schema, correct it and mention the correction in the summary. If a column or value is missing, provide an explanation of the issue and adjust the query accordingly.
|
88 |
+
2. If there is a spelling mistake in the column name or value, attempt to correct it by matching the closest possible column or value from the schema. Mention this correction in the summary to clarify any changes made.
|
89 |
+
3. Ensure that the correct columns and values are used based on the schema provided. Verify the query against the schema to confirm accuracy.
|
90 |
+
4. Include column name headers in the query results for clarity.
|
91 |
+
|
92 |
+
Always provide your answer in the JSON format below:
|
93 |
+
|
94 |
+
{{ "summary": "your-summary", "query": "your-query" }}
|
95 |
+
|
96 |
+
Output ONLY JSON.
|
97 |
+
In the preceding JSON response, substitute "your-query" with a MariaDB query to retrieve the requested data.
|
98 |
+
In the preceding JSON response, substitute "your-summary" with a summary of the query and any corrections or clarifications made.
|
99 |
+
Always include all columns in the table.
|
100 |
+
"""
|
101 |
+
|
102 |
+
prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
|
103 |
+
output = pipe(prompt, max_new_tokens=100)
|
104 |
+
|
105 |
+
generated_text = output[0]['generated_text']
|
106 |
+
sql_query = generated_text.split("SQL query:")[-1].strip()
|
107 |
+
|
108 |
+
# Basic validation
|
109 |
+
if not sql_query.lower().startswith(('select', 'show', 'describe')):
|
110 |
+
raise ValueError("Generated text is not a valid SQL query")
|
111 |
+
|
112 |
+
return {"output": sql_query}
|
113 |
+
except Exception as e:
|
114 |
+
raise HTTPException(status_code=500, detail=str(e))
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
import uvicorn
|
118 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
requests==2.27.*
|
2 |
+
uvicorn[standard]==0.17.*
|
3 |
+
sentencepiece==0.1.*
|
4 |
+
torch==1.13.1
|
5 |
+
numpy<2.0.0
|
6 |
+
fastapi==0.74.*
|
7 |
+
transformers==4.*
|