import os
import re
import sqlite3
from sqlite3 import Error

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from openai import OpenAI
import google.generativeai as genai

# Prompt shared by every backend. Each backend supplies its own
# {response_hint} so the instruction wording it originally sent is preserved.
_PROMPT_TEMPLATE = """Below are SQL table schemas paired with instructions that describe a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. Select all columns unless specified in specific.
Example row :1,Michael,,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium

### Instruction:
{user_prompt}

### Input:
CREATE TABLE ticket_dataset({table_columns});

### Response: {response_hint}"""

_LOCAL_RESPONSE_HINT = "(Return only query , nothing extra)"
_API_RESPONSE_HINT = "(Return only generated query based on user_prompt , nothing extra)"
_QUESTION_PROMPT = (
    "Press Enter for default question or Enter user prompt without newline characters: "
)

# A fenced Markdown code block, optionally tagged with a language (```sql ... ```).
_CODE_FENCE_RE = re.compile(r"```(?:[A-Za-z]*\n)?(.*?)```", re.DOTALL)
# Start of a query. Upper-case only, so ordinary prose words such as
# "select" or "with" in an LLM's explanation are not mistaken for SQL.
_SQL_START_RE = re.compile(r"\b(?:SELECT|WITH)\b")


class SQLPromptModel:
    """Turn natural-language questions into SQLite queries and run them.

    Three generation backends are available:

    * ``text2sql`` - a local PEFT checkpoint (loaded lazily, 4-bit quantized),
    * ``text2sql_chatgpt`` - OpenAI chat completions,
    * ``text2sql_gemini`` - Google Gemini.

    Generated queries are executed against the SQLite database opened at
    construction time (``self.conn``) via ``execute_query``.
    """

    def __init__(self, model_dir, database, openai_api_key=None, google_api_key=None):
        """Configure the backends and open the database.

        Args:
            model_dir: Path of the PEFT checkpoint used by ``text2sql``. It is
                only loaded the first time ``text2sql`` is called.
            database: Path of the SQLite database file.
            openai_api_key: OpenAI key; defaults to ``$OPENAI_API_KEY``. When no
                key is available, ``text2sql_chatgpt`` raises ``RuntimeError``.
            google_api_key: Gemini key; defaults to ``$GOOGLE_API_KEY``.
        """
        self.model_dir = model_dir
        self.database = database
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # Populated on the first text2sql() call so API-only usage stays cheap.
        self.model = None
        self.tokenizer = None

        # SECURITY: API keys used to be hard-coded here and were therefore
        # leaked with the source. Revoke those keys; they are now taken from
        # the arguments or the environment.
        openai_api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        self.chatgpt_client = OpenAI(api_key=openai_api_key) if openai_api_key else None
        self.genai = genai
        self.genai.configure(api_key=google_api_key or os.environ.get("GOOGLE_API_KEY"))
        self.genai_model = genai.GenerativeModel("gemini-pro")
        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database.

        Returns:
            The rows of ``PRAGMA table_info`` as tuples
            ``(cid, name, type, notnull, dflt_value, pk)``, or ``None`` when the
            table does not exist.
        """
        cursor = self.conn.cursor()
        try:
            # Table-valued pragma form lets the name be a bound parameter
            # instead of being interpolated into the SQL text.
            cursor.execute("SELECT * FROM pragma_table_info(?)", (table_name,))
            schema = cursor.fetchall()
        finally:
            cursor.close()
        if schema:
            return schema
        print(f"Table {table_name} does not exist or has no schema.")
        return None

    @staticmethod
    def _resolve_question(user_prompt, inp_prompt):
        """Pick the question: UI input if given, else stdin, else the default."""
        if inp_prompt is not None:
            return inp_prompt
        typed = input(_QUESTION_PROMPT).strip()
        return typed or user_prompt

    @staticmethod
    def _build_prompt(schema, question, response_hint):
        """Render the shared prompt for ``question`` over the given table schema."""
        # schema rows come from PRAGMA table_info: index 1 = name, 2 = type.
        table_columns = ", ".join(f"{col[1]} {col[2]}" for col in schema)
        return _PROMPT_TEMPLATE.format(
            user_prompt=question,
            table_columns=table_columns,
            response_hint=response_hint,
        )

    @staticmethod
    def _extract_sql(text):
        """Pull the SQL statement out of free-form LLM output.

        Strips a Markdown code fence if present, then returns the text from the
        first upper-case ``SELECT``/``WITH`` up to and including the first
        following ``;`` (or to the end when there is no ``;``). If no query
        start is found, the (unfenced) text is returned stripped.
        """
        fenced = _CODE_FENCE_RE.search(text)
        if fenced:
            text = fenced.group(1)
        start = _SQL_START_RE.search(text)
        if start is None:
            return text.strip()
        end = text.find(";", start.start())
        if end == -1:
            return text[start.start():].strip()
        return text[start.start():end + 1]

    def _ensure_local_model(self):
        """Load the quantized PEFT model and its tokenizer on first use."""
        if self.model is None:
            self.model = AutoPeftModelForCausalLM.from_pretrained(
                self.model_dir,
                low_cpu_mem_usage=True,
                quantization_config=self.bnb_config,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the local fine-tuned model.

        Args:
            schema: Rows returned by ``fetch_table_schema``.
            user_prompt: Default question, used when no other input is given.
            inp_prompt: Question supplied by a UI (e.g. Gradio). When ``None``
                the user is asked on stdin.

        Returns:
            The generated SQL query text.
        """
        question = self._resolve_question(user_prompt, inp_prompt)
        prompt = self._build_prompt(schema, question, _LOCAL_RESPONSE_HINT)

        self._ensure_local_model()
        # Move the encoded input to the device the model was loaded on.
        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(
            self.model.device
        )
        outputs = self.model.generate(**encoded, max_new_tokens=200)
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) breaks whenever decoding does not round-trip the prompt.
        new_tokens = outputs[:, encoded["input_ids"].shape[-1]:]
        response = self.tokenizer.batch_decode(
            new_tokens.detach().cpu().numpy(), skip_special_tokens=True
        )[0]
        return self._extract_sql(response)

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with OpenAI ``gpt-3.5-turbo``.

        Arguments as for ``text2sql``.

        Raises:
            RuntimeError: no OpenAI API key was configured.
        """
        if self.chatgpt_client is None:
            raise RuntimeError("OpenAI API key not configured (set OPENAI_API_KEY).")
        question = self._resolve_question(user_prompt, inp_prompt)
        prompt = self._build_prompt(schema, question, _API_RESPONSE_HINT)
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a expert SQL developer , generate a sql query and return it"},
                {"role": "user", "content": prompt},
            ],
        )
        return self._extract_sql(completion.choices[0].message.content or "")

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with Google ``gemini-pro``.

        Arguments as for ``text2sql``.
        """
        question = self._resolve_question(user_prompt, inp_prompt)
        prompt = self._build_prompt(schema, question, _API_RESPONSE_HINT)
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        return self._extract_sql(completion.text)

    def execute_query(self, query):
        """Execute ``query`` on the database, print and return the result.

        Returns:
            ``(rows, columns)``: the fetched row tuples and the column names.
            Both are empty for statements that produce no result set.

        NOTE(review): ``query`` is LLM-generated from user text and is executed
        unchecked (and committed). Consider opening the database read-only if
        writes are not intended.
        """
        print(query)
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            if cur.description is None:
                # e.g. UPDATE/DDL: there are no columns or rows to show.
                rows, col = [], []
            else:
                col = [header[0] for header in cur.description]
                print(tuple(col))
                print("-" * sum(len(col_name) + 4 for col_name in col))
                rows = cur.fetchall()
                for member in rows:
                    print(member)
        finally:
            cur.close()
        self.conn.commit()
        return rows, col


if __name__ == "__main__":
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"ticket_dataset.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        # The schema does not change between questions, so fetch it once.
        table_schema = sql_model.fetch_table_schema("ticket_dataset")
        # Without a table there is nothing to query (previously looped forever).
        while table_schema:
            # Alternative backends:
            # query = sql_model.text2sql(table_schema, user_prompt)
            # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
            query = sql_model.text2sql_gemini(table_schema, user_prompt)
            print(query)
            try:
                sql_model.execute_query(query)
            except Error as exc:
                # A bad generated query should not end the interactive session.
                print(f"Query failed: {exc}")
    except (KeyboardInterrupt, EOFError):
        print()
    finally:
        sql_model.conn.close()