Spaces:
Sleeping
Sleeping
File size: 7,056 Bytes
7de61f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
import sqlite3
from sqlite3 import Error

import google.generativeai as genai
from openai import OpenAI
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class SQLPromptModel:
    """Text-to-SQL helper over a SQLite database.

    Builds a schema-aware prompt from a table's columns and sends it to
    one of three backends (local fine-tuned model, OpenAI ChatGPT, or
    Google Gemini), then can execute the returned query against the
    database.
    """

    def __init__(self, model_dir, database):
        """Set up API clients and open the SQLite connection.

        Args:
            model_dir: Path to a local fine-tuned checkpoint. Currently
                unused because the local-model loading code below is
                commented out.
            database: Path to the SQLite database file.
        """
        self.model_dir = model_dir
        self.database = database
        # Quantization config kept for when the local PEFT model is
        # re-enabled (see the commented-out loading code below).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # self.model = AutoPeftModelForCausalLM.from_pretrained(
        #     self.model_dir, low_cpu_mem_usage=True, quantization_config=bnb_config
        # )
        # self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # self.model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        # self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        # SECURITY: API keys are read from the environment. The literal
        # keys previously hard-coded here were committed to source and
        # must be treated as leaked — revoke and rotate them.
        self.chatgpt_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        self.genai = genai
        self.genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        self.genai_model = genai.GenerativeModel('gemini-pro')
        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database.

        Args:
            table_name: Name of the table to inspect.

        Returns:
            The PRAGMA table_info rows (cid, name, type, notnull,
            dflt_value, pk), or None if the table does not exist.

        Raises:
            ValueError: If table_name is not a plain identifier.
        """
        # PRAGMA statements cannot take bound parameters, so validate
        # the identifier instead of interpolating arbitrary text
        # (prevents SQL injection via the table name).
        if not table_name.isidentifier():
            raise ValueError(f"Invalid table name: {table_name!r}")
        cursor = self.conn.cursor()
        try:
            cursor.execute(f"PRAGMA table_info({table_name})")
            schema = cursor.fetchall()
        finally:
            cursor.close()
        if schema:
            return schema
        print(f"Table {table_name} does not exist or has no schema.")
        return None

    def _build_prompt(self, schema, user_prompt, inp_prompt, response_note):
        """Build the instruction prompt shared by all three backends.

        Args:
            schema: PRAGMA table_info rows for the target table.
            user_prompt: Default instruction text.
            inp_prompt: Override instruction; if None the user is asked
                interactively (empty answer keeps user_prompt).
            response_note: Backend-specific text for the Response line.

        Returns:
            The fully assembled prompt string.
        """
        table_columns = ', '.join(f"{col[1]} {col[2]}" for col in schema)
        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        ### Instruction: {user_prompt} ###
        Input: CREATE TABLE sql_pdf({table_columns});
        ### Response: ({response_note})"""
        if inp_prompt is None:
            # Interactive fallback (used by the CLI path / gradio-less runs).
            inp_prompt = input("Press Enter for default question or Enter user prompt without newline characters: ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        return prompt

    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate an SQL query with the local fine-tuned model.

        inp_prompt exists for the gradio code path; when None the user
        is prompted interactively.

        NOTE(review): self.model and self.tokenizer are never assigned
        because the loading code in __init__ is commented out, so this
        method raises AttributeError until that is re-enabled.
        """
        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt,
            "Return only query , nothing extra",
        )
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(next(self.model.parameters()).device)  # run on the model's device
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)
        response = self.tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0]
        # The decoded text echoes the prompt; return only the completion.
        return response[len(prompt):]

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate an SQL query via the OpenAI chat-completions API."""
        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt,
            "Return only generated query based on user_prompt , nothing extra",
        )
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a expert SQL developer , generate a sql query and return it"},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate an SQL query via Google Gemini.

        Extracts the first "SELECT ... ;" span from the model output
        when present; otherwise returns the raw generated text.
        """
        prompt = self._build_prompt(
            schema, user_prompt, inp_prompt,
            "Return only generated query based on user_prompt , nothing extra",
        )
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        generated_query = completion.text
        start_index = generated_query.find("SELECT")
        end_index = generated_query.find(";", start_index) + 1
        print(start_index, end_index)
        if start_index != -1 and end_index != 0:
            return generated_query[start_index:end_index]
        return generated_query

    def execute_query(self, query):
        """Execute the query, print a simple table, and return results.

        Args:
            query: SQL text to run against self.conn.

        Returns:
            Tuple of (rows, column_names).
        """
        print(query)
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            col = [header[0] for header in cur.description]
            dash = "-" * sum(len(col_name) + 4 for col_name in col)
            print(tuple(col))
            print(dash)
            rows = []
            for member in cur:
                rows.append(member)
                print(member)
        finally:
            # Release the cursor even if execute() raises.
            cur.close()
        self.conn.commit()
        return rows, col
if __name__ == "__main__":
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"sql_pdf.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        while True:
            table_schema = sql_model.fetch_table_schema("sql_pdf")
            if not table_schema:
                # No schema means nothing to query; stop instead of looping.
                break
            # query = sql_model.text2sql(table_schema, user_prompt)
            # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
            query = sql_model.text2sql_gemini(table_schema, user_prompt)
            print(query)
            sql_model.execute_query(query)
    finally:
        # BUG FIX: the connection was previously closed INSIDE the loop,
        # so every iteration after the first failed on a closed database.
        sql_model.conn.close()
|