barathm111 commited on
Commit
3de53f6
1 Parent(s): dd0bda0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -14,7 +14,7 @@ app = FastAPI()
14
  pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b", pad_token_id=2)
15
 
16
  class QueryRequest(BaseModel):
17
- text: str
18
 
19
  def get_db_connection():
20
  """Create a new database connection."""
@@ -66,13 +66,10 @@ def get_database_schema():
66
  def home():
67
  return {"message": "SQL Generation Server is running"}
68
 
69
- @app.post("/generate")
70
- def generate(request: QueryRequest):
71
  try:
72
- # Log the incoming request text for debugging
73
- print(f"Received request with text: {request.text}")
74
-
75
- text = request.text
76
 
77
  # Fetch the database schema
78
  schema = get_database_schema()
@@ -80,13 +77,27 @@ def generate(request: QueryRequest):
80
 
81
  # Construct the system message
82
  system_message = f"""
83
- You are a helpful, cheerful database assistant.
84
- Use the following dynamically retrieved database schema when creating your answers:
 
 
 
 
85
 
86
- {schema_str}
 
 
 
87
 
88
- [Additional instructions as in your original code]
89
- """
 
 
 
 
 
 
 
90
 
91
  prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
92
  output = pipe(prompt, max_new_tokens=100)
@@ -98,11 +109,19 @@ def generate(request: QueryRequest):
98
  if not sql_query.lower().startswith(('select', 'show', 'describe')):
99
  raise ValueError("Generated text is not a valid SQL query")
100
 
101
- return {"output": sql_query}
 
 
 
 
 
 
 
 
 
102
  except Exception as e:
103
  raise HTTPException(status_code=500, detail=str(e))
104
 
105
-
106
  if __name__ == "__main__":
107
  import uvicorn
108
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
14
  pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b", pad_token_id=2)
15
 
16
  class QueryRequest(BaseModel):
17
+ query: str
18
 
19
  def get_db_connection():
20
  """Create a new database connection."""
 
66
  def home():
67
  return {"message": "SQL Generation Server is running"}
68
 
69
+ @app.post("/query")
70
+ def handle_query(request: QueryRequest):
71
  try:
72
+ text = request.query
 
 
 
73
 
74
  # Fetch the database schema
75
  schema = get_database_schema()
 
77
 
78
  # Construct the system message
79
  system_message = f"""
80
+ You are a helpful, cheerful database assistant.
81
+ Use the following dynamically retrieved database schema when creating your answers:
82
+
83
+ {schema_str}
84
+
85
+ When creating your answers, consider the following:
86
 
87
+ 1. If a query involves a column or value that is not present in the provided database schema, correct it and mention the correction in the summary. If a column or value is missing, provide an explanation of the issue and adjust the query accordingly.
88
+ 2. If there is a spelling mistake in the column name or value, attempt to correct it by matching the closest possible column or value from the schema. Mention this correction in the summary to clarify any changes made.
89
+ 3. Ensure that the correct columns and values are used based on the schema provided. Verify the query against the schema to confirm accuracy.
90
+ 4. Include column name headers in the query results for clarity.
91
 
92
+ Always provide your answer in the JSON format below:
93
+
94
+ {{ "summary": "your-summary", "query": "your-query" }}
95
+
96
+ Output ONLY JSON.
97
+ In the preceding JSON response, substitute "your-query" with a MariaDB query to retrieve the requested data.
98
+ In the preceding JSON response, substitute "your-summary" with a summary of the query and any corrections or clarifications made.
99
+ Always include all columns in the table.
100
+ """
101
 
102
  prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
103
  output = pipe(prompt, max_new_tokens=100)
 
109
  if not sql_query.lower().startswith(('select', 'show', 'describe')):
110
  raise ValueError("Generated text is not a valid SQL query")
111
 
112
+ # Example: execute the generated SQL query and return the results
113
+ conn = get_db_connection()
114
+ cursor = conn.cursor()
115
+ cursor.execute(sql_query)
116
+ results = cursor.fetchall()
117
+
118
+ cursor.close()
119
+ conn.close()
120
+
121
+ return {"sql": sql_query, "results": results}
122
  except Exception as e:
123
  raise HTTPException(status_code=500, detail=str(e))
124
 
 
125
  if __name__ == "__main__":
126
  import uvicorn
127
  uvicorn.run(app, host="0.0.0.0", port=7860)