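"""Text-to-SQL demo: turns natural-language questions into SQLite queries for the
ticket_dataset table and runs them against ticket_dataset.db.

Query generation can use a local PEFT fine-tuned model (disabled by default below),
OpenAI gpt-3.5-turbo, or Google Gemini (gemini-pro).
"""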
import os
import sqlite3
from sqlite3 import Error
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from openai import OpenAI
import google.generativeai as genai
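# Third-party dependencies (assumed package names for the imports above):
#   pip install torch peft transformers bitsandbytes openai google-generativeai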

class SQLPromptModel:
    def __init__(self, model_dir, database):
        self.model_dir = model_dir
        self.database = database
        # Local fine-tuned model loading is disabled by default; uncomment one of the
        # blocks below (text2sql() needs self.model and self.tokenizer to be set).
        # peft_model_dir = self.model_dir
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # self.model = AutoPeftModelForCausalLM.from_pretrained(
        #     peft_model_dir, low_cpu_mem_usage=True, quantization_config=bnb_config
        # )
        # self.tokenizer = AutoTokenizer.from_pretrained(peft_model_dir)
        # self.model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        # self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        # Read API keys from environment variables instead of hard-coding secrets.
        self.chatgpt_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.genai = genai
        self.genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.genai_model = genai.GenerativeModel('gemini-pro')

        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database."""
        cursor = self.conn.cursor()
        cursor.execute(f"PRAGMA table_info({table_name})")
        schema = cursor.fetchall()
        if schema:
            return schema
        else:
            print(f"Table {table_name} does not exist or has no schema.")
            return None
        
    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query from the user prompt and table schema.

        Requires the local fine-tuned model to be loaded in __init__.
        inp_prompt overrides user_prompt when called programmatically (e.g. from Gradio).
        """
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        Select all columns unless specified in specific.
        Example row :1,Michael,michael@ignitarium.com,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium
        ### Instruction: {user_prompt} ### 
        Input: CREATE TABLE ticket_dataset({table_columns});
        ### Response: (Return only query , nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a custom prompt (no newlines): ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")

        """Text to SQL query generation"""
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(next(self.model.parameters()).device)  # Move input to the device of the model
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)
        response = self.tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0][:]
        return response[len(prompt):]

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        Select all columns unless specified in specific.
        Example row :1,Michael,michael@ignitarium.com,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium
        ### Instruction: {user_prompt} ### 
        Input: CREATE TABLE ticket_dataset({table_columns});
        ### Response: (Return only generated query based on user_prompt , nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a custom prompt (no newlines): ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")
        print(prompt)  # debug: show the final prompt sent to the model
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are an expert SQL developer; generate a SQL query and return it."},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        Select all columns unless specified in specific.
        Example row :1,Michael,michael@ignitarium.com,59,Female,Headphones,2023-01-03,General inquiry,Email,44 hours,88 hours,4,Technical Issue,Server crashes due to memory leaks in custom-developed software.,Closed,"Debug and optimize the software code to identify and fix memory leaks, and implement regular monitoring for early detection.",Medium
        ### Instruction: {user_prompt} ### 
        Input: CREATE TABLE ticket_dataset({table_columns});
        ### Response: (Return only generated query based on user_prompt , nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a custom prompt (no newlines): ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")
        print(prompt)  # debug: show the final prompt sent to the model
        completion = self.genai_model.generate_content(prompt)
        generated_query = completion.text
        # Extract the first "SELECT ... ;" statement from the model response;
        # fall back to the full response if no complete statement is found.
        start_index = generated_query.find("SELECT")
        end_index = generated_query.find(";", start_index) + 1
        print(start_index, end_index)  # debug: span of the extracted statement
        if start_index != -1 and end_index != 0:
            return generated_query[start_index:end_index]
        else:
            return generated_query
        


    def execute_query(self, query):
        """Execute the query against the database and return the rows and column names."""
        print(query)
        cur = self.conn.cursor()
        cur.execute(query)
        col = [header[0] for header in cur.description]
        dash = "-" * sum(len(col_name) + 4 for col_name in col)
        print(tuple(col))
        print(dash)
        rows = []
        for member in cur:
            rows.append(member)
            print(member)
        cur.close()
        self.conn.commit()
        # print(rows)
        return rows, col

if __name__ == "__main__":
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"ticket_dataset.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        while True:
            table_schema = sql_model.fetch_table_schema("ticket_dataset")
            if table_schema:
                # query = sql_model.text2sql(table_schema, user_prompt)
                # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
                query = sql_model.text2sql_gemini(table_schema, user_prompt)
                print(query)
                sql_model.execute_query(query)
            else:
                # Stop instead of looping forever when the table is missing.
                break
    except KeyboardInterrupt:
        # Allow Ctrl+C to exit the interactive loop cleanly.
        pass
    finally:
        sql_model.conn.close()