# Scrape artifact from the hosting page (HF Spaces status line) — not part of the program:
# Spaces:
# Sleeping
# Sleeping
import os
import sqlite3
from sqlite3 import Error

import google.generativeai as genai
from openai import OpenAI
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class SQLPromptModel:
    """Turn natural-language questions into SQLite queries (via a local LLM,
    OpenAI chat, or Gemini) and execute them against a SQLite database.

    Attributes:
        model_dir: path to a local fine-tuned checkpoint (used only if the
            commented-out local-model loading in __init__ is re-enabled).
        database: path to the SQLite database file.
        conn: open sqlite3 connection to ``database``.
    """

    def __init__(self, model_dir, database):
        self.model_dir = model_dir
        self.database = database
        # 4-bit quantization config for the (currently disabled) local model.
        # Kept so re-enabling the loading below is a one-line change.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # Local model loading is disabled; text2sql() will raise AttributeError
        # until self.model / self.tokenizer are initialised here.
        # self.model = AutoPeftModelForCausalLM.from_pretrained(
        #     self.model_dir, low_cpu_mem_usage=True, quantization_config=bnb_config
        # )
        # self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # SECURITY: API keys were hard-coded here and are now committed to
        # history — rotate them. Environment variables take precedence; the
        # literals remain only as a backward-compatible fallback.
        self.chatgpt_client = OpenAI(
            api_key=os.environ.get(
                "OPENAI_API_KEY",
                "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",
            )
        )
        self.genai = genai
        self.genai.configure(
            api_key=os.environ.get(
                "GOOGLE_API_KEY", "AIzaSyAFG94rVbm9eWepO5uPGsMha8XJ-sHbMdA"
            )
        )
        self.genai_model = genai.GenerativeModel('gemini-pro')
        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database.

        Returns the ``PRAGMA table_info`` rows (cid, name, type, ...) or
        None when the table does not exist / has no columns.
        """
        cursor = self.conn.cursor()
        # NOTE(review): PRAGMA statements cannot take bound parameters, so the
        # identifier is interpolated directly — only pass trusted table names.
        cursor.execute(f"PRAGMA table_info({table_name})")
        schema = cursor.fetchall()
        if schema:
            return schema
        print(f"Table {table_name} does not exist or has no schema.")
        return None

    def _build_prompt(self, schema, user_prompt, response_note):
        """Assemble the shared few-shot prompt for the ticket_dataset table.

        ``response_note`` is the backend-specific instruction placed in the
        ``### Response`` section (the only part that differed between the
        three previously copy-pasted prompt builders).
        """
        table_columns = ', '.join(f"{col[1]} {col[2]}" for col in schema)
        return f"""Below are SQL table schemas paired with instructions that describe a task.
Using valid SQLite, write a response that appropriately completes the request for the provided tables.
Select all columns unless specified in specific.
Example row :1,Michael,michael@ignitarium.com,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium
### Instruction: {user_prompt} ###
Input: CREATE TABLE ticket_dataset({table_columns});
### Response: ({response_note})"""

    def _apply_user_override(self, prompt, user_prompt, inp_prompt):
        """Swap ``user_prompt`` for an override inside ``prompt``.

        ``inp_prompt`` (used by the gradio path) wins when given; otherwise
        the user is asked interactively, and an empty answer keeps the
        default question. A trailing space is appended to the replacement,
        matching the original behavior.
        """
        if inp_prompt is not None:
            return prompt.replace(user_prompt, inp_prompt + " ")
        entered = input("Press Enter for default question or Enter user prompt without newline characters: ").strip()
        if entered:
            return prompt.replace(user_prompt, entered + " ")
        return prompt

    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate SQL with the locally loaded model. inp_prompt is for gradio purpose.

        NOTE(review): requires self.model / self.tokenizer, whose loading is
        commented out in __init__ — this raises AttributeError until then.
        """
        prompt = self._apply_user_override(
            self._build_prompt(schema, user_prompt, "Return only query , nothing extra"),
            user_prompt,
            inp_prompt,
        )
        # Move inputs to the same device as the model before generating.
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(next(self.model.parameters()).device)
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)
        response = self.tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0]
        # The decoded text echoes the prompt; return only the completion.
        return response[len(prompt):]

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate SQL via the OpenAI chat API. inp_prompt is for gradio purpose."""
        prompt = self._apply_user_override(
            self._build_prompt(
                schema,
                user_prompt,
                "Return only generated query based on user_prompt , nothing extra",
            ),
            user_prompt,
            inp_prompt,
        )
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a expert SQL developer , generate a sql query and return it"},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate SQL via Gemini. inp_prompt is for gradio purpose.

        Trims the model response to the first ``SELECT ... ;`` span when one
        is present; otherwise returns the raw response text.
        """
        prompt = self._apply_user_override(
            self._build_prompt(
                schema,
                user_prompt,
                "Return only generated query based on user_prompt , nothing extra",
            ),
            user_prompt,
            inp_prompt,
        )
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        generated_query = completion.text
        start_index = generated_query.find("SELECT")
        end_index = generated_query.find(";", start_index) + 1
        print(start_index, end_index)
        if start_index != -1 and end_index != 0:
            return generated_query[start_index:end_index]
        return generated_query

    def execute_query(self, query):
        """Executing the query on database and returning rows and columns."""
        print(query)
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            col = [header[0] for header in cur.description]
            print(tuple(col))
            # Separator sized to the header widths, as in the original output.
            print("-" * sum(len(col_name) + 4 for col_name in col))
            rows = []
            for member in cur:
                rows.append(member)
                print(member)
        finally:
            # fix: cursor was leaked when execute() raised.
            cur.close()
        self.conn.commit()
        return rows, col
if __name__ == "__main__":
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"ticket_dataset.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        # Interactive loop: text2sql_gemini() prompts on stdin each pass.
        while True:
            table_schema = sql_model.fetch_table_schema("ticket_dataset")
            if table_schema:
                # Alternative backends, kept for reference:
                # query = sql_model.text2sql(table_schema, user_prompt)
                # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
                query = sql_model.text2sql_gemini(table_schema, user_prompt)
                print(query)
                sql_model.execute_query(query)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to leave the interactive loop.
        pass
    finally:
        # fix: the connection was closed inside the loop, so the second
        # iteration would have operated on a closed connection.
        sql_model.conn.close()