Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,11 @@
|
|
1 |
-
import
|
2 |
-
from fastapi import FastAPI, HTTPException, Request
|
3 |
-
from fastapi.responses import JSONResponse
|
4 |
from pydantic import BaseModel
|
5 |
from transformers import pipeline
|
6 |
|
7 |
-
# Configure logging
|
8 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
9 |
-
logger = logging.getLogger(__name__)
|
10 |
-
|
11 |
app = FastAPI()
|
12 |
|
13 |
# Initialize the text generation pipeline
|
14 |
-
|
15 |
-
pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
|
16 |
-
logger.info("Model loaded successfully")
|
17 |
-
except Exception as e:
|
18 |
-
logger.error(f"Failed to load the model: {str(e)}")
|
19 |
-
raise
|
20 |
|
21 |
class QueryRequest(BaseModel):
|
22 |
text: str
|
@@ -24,19 +13,15 @@ class QueryRequest(BaseModel):
|
|
24 |
@app.get("/")
|
25 |
def home():
|
26 |
return {"message": "SQL Generation Server is running"}
|
27 |
-
|
28 |
@app.post("/generate")
|
29 |
-
|
30 |
try:
|
31 |
text = request.text
|
32 |
-
logger.info(f"Received request: {text}")
|
33 |
-
|
34 |
prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
|
35 |
output = pipe(prompt, max_new_tokens=100)
|
36 |
|
37 |
generated_text = output[0]['generated_text']
|
38 |
sql_query = generated_text.split("SQL query:")[-1].strip()
|
39 |
-
|
40 |
# Basic validation to ensure it's a valid SQL query
|
41 |
if not sql_query.lower().startswith(('select', 'show', 'describe', 'insert', 'update', 'delete')):
|
42 |
raise ValueError("Generated text is not a valid SQL query")
|
@@ -48,28 +33,17 @@ async def generate(request: QueryRequest):
|
|
48 |
allowed_keywords = {
|
49 |
'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
|
50 |
}
|
51 |
-
|
52 |
# Ensure the query only contains allowed keywords
|
53 |
tokens = sql_query.lower().split()
|
54 |
for token in tokens:
|
55 |
if not any(token.startswith(keyword) for keyword in allowed_keywords):
|
56 |
-
raise ValueError(
|
57 |
|
58 |
-
logger.info(f"Generated SQL query: {sql_query}")
|
59 |
return {"output": sql_query}
|
60 |
-
except ValueError as ve:
|
61 |
-
logger.warning(f"Validation error: {str(ve)}")
|
62 |
-
raise HTTPException(status_code=400, detail=str(ve))
|
63 |
except Exception as e:
|
64 |
-
|
65 |
-
|
66 |
|
67 |
-
@app.exception_handler(HTTPException)
|
68 |
-
async def http_exception_handler(request: Request, exc: HTTPException):
|
69 |
-
return JSONResponse(
|
70 |
-
status_code=exc.status_code,
|
71 |
-
content={"message": exc.detail},
|
72 |
-
)
|
73 |
|
74 |
if __name__ == "__main__":
|
75 |
import uvicorn
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
|
|
|
|
2 |
from pydantic import BaseModel
|
3 |
from transformers import pipeline
|
4 |
|
|
|
|
|
|
|
|
|
5 |
app = FastAPI()
|
6 |
|
7 |
# Initialize the text generation pipeline
|
8 |
+
pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
class QueryRequest(BaseModel):
|
11 |
text: str
|
|
|
13 |
@app.get("/")
|
14 |
def home():
|
15 |
return {"message": "SQL Generation Server is running"}
|
|
|
16 |
@app.post("/generate")
|
17 |
+
def generate(request: QueryRequest):
|
18 |
try:
|
19 |
text = request.text
|
|
|
|
|
20 |
prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
|
21 |
output = pipe(prompt, max_new_tokens=100)
|
22 |
|
23 |
generated_text = output[0]['generated_text']
|
24 |
sql_query = generated_text.split("SQL query:")[-1].strip()
|
|
|
25 |
# Basic validation to ensure it's a valid SQL query
|
26 |
if not sql_query.lower().startswith(('select', 'show', 'describe', 'insert', 'update', 'delete')):
|
27 |
raise ValueError("Generated text is not a valid SQL query")
|
|
|
33 |
allowed_keywords = {
|
34 |
'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
|
35 |
}
|
|
|
36 |
# Ensure the query only contains allowed keywords
|
37 |
tokens = sql_query.lower().split()
|
38 |
for token in tokens:
|
39 |
if not any(token.startswith(keyword) for keyword in allowed_keywords):
|
40 |
+
raise ValueError("Generated text contains invalid SQL syntax")
|
41 |
|
|
|
42 |
return {"output": sql_query}
|
|
|
|
|
|
|
43 |
except Exception as e:
|
44 |
+
raise HTTPException(status_code=500, detail=str(e))
|
45 |
+
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
if __name__ == "__main__":
|
49 |
import uvicorn
|