knowledge_model / txt2sql.py
arjunanand13's picture
Upload 10 files
1d35b34 verified
raw
history blame
8.31 kB
import os
import sqlite3
from sqlite3 import Error

import google.generativeai as genai
from openai import OpenAI
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig
class SQLPromptModel:
    """Turn natural-language questions into SQLite queries and run them.

    Three generators share one prompt template:

    * :meth:`text2sql` uses a local fine-tuned PEFT model.
    * :meth:`text2sql_chatgpt` uses OpenAI chat completions.
    * :meth:`text2sql_gemini` uses Gemini.

    Generated queries are executed against the SQLite database opened in
    ``__init__``.
    """

    # Prompt shared by every generator. ``instruction`` is the user question,
    # ``table_columns`` is the "name TYPE" list built from PRAGMA table_info
    # rows, and ``response_hint`` is the per-backend output instruction.
    _PROMPT_TEMPLATE = (
        "Below are SQL table schemas paired with instructions that describe a task.\n"
        "Using valid SQLite, write a response that appropriately completes the request for the provided tables.\n"
        "Select all columns unless specified in specific.\n"
        'Example row :1,Michael,michael@ignitarium.com,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium\n'
        "### Instruction: {instruction} ###\n"
        "Input: CREATE TABLE ticket_dataset({table_columns});\n"
        "### Response: {response_hint}"
    )
    _LOCAL_RESPONSE_HINT = "(Return only query , nothing extra)"
    _API_RESPONSE_HINT = "(Return only generated query based on user_prompt , nothing extra)"

    def __init__(self, model_dir, database, openai_api_key=None, google_api_key=None):
        """Configure the LLM clients and open the SQLite database.

        Args:
            model_dir: Directory of the fine-tuned PEFT checkpoint used by
                :meth:`text2sql`. It is loaded lazily, on first use.
            database: Path handed to ``sqlite3.connect``.
            openai_api_key: OpenAI key. Falls back to the ``OPENAI_API_KEY``
                environment variable. Without any key, the OpenAI client is
                not created and :meth:`text2sql_chatgpt` raises.
            google_api_key: Gemini key. Falls back to the ``GOOGLE_API_KEY``
                environment variable.
        """
        self.model_dir = model_dir
        self.database = database
        # 4-bit NF4 quantization settings for the local model. Kept on the
        # instance so the lazy loader can use them.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # Filled in on demand by _ensure_local_model().
        self.model = None
        self.tokenizer = None
        # Secrets come from arguments or the environment, never from source.
        openai_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        self.chatgpt_client = OpenAI(api_key=openai_key) if openai_key else None
        self.genai = genai
        google_key = google_api_key or os.environ.get("GOOGLE_API_KEY")
        if google_key:
            self.genai.configure(api_key=google_key)
        self.genai_model = genai.GenerativeModel('gemini-pro')
        self.conn = sqlite3.connect(self.database)

    def _ensure_local_model(self):
        """Load the PEFT model and tokenizer from ``model_dir`` if not loaded yet."""
        if getattr(self, "model", None) is None or getattr(self, "tokenizer", None) is None:
            self.model = AutoPeftModelForCausalLM.from_pretrained(
                self.model_dir, low_cpu_mem_usage=True, quantization_config=self.bnb_config
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    def _build_prompt(self, schema, user_prompt, inp_prompt, response_hint):
        """Render the text-to-SQL prompt for ``schema``.

        Args:
            schema: Rows from ``PRAGMA table_info``. Column 1 is the name and
                column 2 the declared type.
            user_prompt: Default question.
            inp_prompt: Caller-supplied question (e.g. from Gradio). When it is
                ``None``, the question is read from stdin, and pressing Enter
                keeps the default. When it is blank, the default is used.
            response_hint: Trailing output instruction for the backend.

        Returns:
            The full prompt string.
        """
        if inp_prompt is None:
            inp_prompt = input("Press Enter for default question or Enter user prompt without newline characters: ").strip()
        # Substitute the instruction directly. The old str.replace() on the
        # rendered prompt also rewrote any other occurrence of user_prompt in
        # the template, and inserted text between every character when
        # user_prompt was "". The trailing space matches the previous output.
        instruction = f"{inp_prompt} " if inp_prompt and inp_prompt.strip() else user_prompt
        table_columns = ', '.join(f"{col[1]} {col[2]}" for col in schema)
        return self._PROMPT_TEMPLATE.format(
            instruction=instruction,
            table_columns=table_columns,
            response_hint=response_hint,
        )

    @staticmethod
    def _extract_sql(text):
        """Pull the first ``SELECT`` statement out of an LLM reply.

        Returns ``text`` unchanged when it contains no ``SELECT``.
        """
        start = text.find("SELECT")
        if start == -1:
            return text
        end = text.find(";", start)
        if end == -1:
            # No terminator: keep the rest, but drop a closing markdown fence
            # so the result is still executable SQL.
            return text[start:].split("```", 1)[0].strip()
        return text[start:end + 1]

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database.

        Returns:
            The ``PRAGMA table_info`` rows, or ``None`` (after printing a
            message) when the table does not exist.
        """
        # PRAGMA arguments cannot be bound as parameters, so quote the name as
        # an SQL identifier instead of interpolating it raw.
        quoted = '"' + str(table_name).replace('"', '""') + '"'
        cursor = self.conn.cursor()
        try:
            cursor.execute(f"PRAGMA table_info({quoted})")
            schema = cursor.fetchall()
        finally:
            cursor.close()
        if schema:
            return schema
        print(f"Table {table_name} does not exist or has no schema.")
        return None

    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the local fine-tuned model.

        ``inp_prompt`` is for the Gradio front end. See :meth:`_build_prompt`
        for how it interacts with ``user_prompt``.

        Returns:
            The decoded text the model generated after the prompt.
        """
        prompt = self._build_prompt(schema, user_prompt, inp_prompt, self._LOCAL_RESPONSE_HINT)
        self._ensure_local_model()
        # Move the input to the device the model's parameters live on.
        device = next(self.model.parameters()).device
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(device)
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)
        # Decode only the newly generated tokens. Slicing the decoded string by
        # len(prompt) breaks when decoding does not reproduce the prompt text
        # exactly, for example after truncation.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens.detach().cpu().tolist(), skip_special_tokens=True)

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with OpenAI ``gpt-3.5-turbo``.

        Returns:
            The raw message content of the first completion choice.

        Raises:
            RuntimeError: No OpenAI API key was configured.
        """
        if self.chatgpt_client is None:
            raise RuntimeError("OpenAI client is not configured: pass openai_api_key or set OPENAI_API_KEY.")
        prompt = self._build_prompt(schema, user_prompt, inp_prompt, self._API_RESPONSE_HINT)
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a expert SQL developer , generate a sql query and return it"},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with Gemini.

        Returns:
            The first ``SELECT`` statement found in the reply, or the full
            reply when none is found (see :meth:`_extract_sql`).
        """
        prompt = self._build_prompt(schema, user_prompt, inp_prompt, self._API_RESPONSE_HINT)
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        return self._extract_sql(completion.text)

    def execute_query(self, query):
        """Execute ``query`` on the database and return ``(rows, column_names)``.

        Statements with no result set (INSERT/UPDATE/DDL) return ``([], [])``.

        NOTE(review): this runs whatever SQL the model produced and commits
        it. Consider a read-only connection if only SELECTs are intended.
        """
        print(query)
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            # description is None for statements that produce no result set.
            col = [header[0] for header in cur.description] if cur.description else []
            print(tuple(col))
            print("-" * sum(len(col_name) + 4 for col_name in col))
            rows = cur.fetchall()
            for row in rows:
                print(row)
        finally:
            cur.close()
        self.conn.commit()
        return rows, col
if __name__ == "__main__":
    # Interactive demo: repeatedly ask for a question, generate SQL for the
    # ticket_dataset table, and run it.
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"ticket_dataset.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        while True:
            table_schema = sql_model.fetch_table_schema("ticket_dataset")
            if not table_schema:
                # Nothing to prompt with. Without this break the loop would
                # spin forever re-printing the "does not exist" message.
                break
            # Swap in sql_model.text2sql (local model) or
            # sql_model.text2sql_chatgpt (OpenAI) to use another backend.
            query = sql_model.text2sql_gemini(table_schema, user_prompt)
            print(query)
            try:
                sql_model.execute_query(query)
            except Error as exc:
                # LLM-generated SQL can be invalid; report it and keep going.
                print(f"Query failed: {exc}")
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or Ctrl-D at the input() prompt, ends the session cleanly.
        print()
    finally:
        sql_model.conn.close()