import os
import sqlite3

from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from openai import OpenAI
import google.generativeai as genai

class SQLPromptModel:
    """Builds a text-to-SQL prompt from a SQLite table schema and generates the
    query with either a local fine-tuned model, OpenAI, or Gemini."""

    def __init__(self, model_dir, database):
        self.model_dir = model_dir
        self.database = database
        # 4-bit quantization config for the (currently disabled) local PEFT model.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=True,
        )
        # self.model = AutoPeftModelForCausalLM.from_pretrained(
        #     self.model_dir, low_cpu_mem_usage=True, quantization_config=bnb_config
        # )
        # self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # self.model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        # self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

        # API keys are read from the environment rather than hardcoded in source.
        self.chatgpt_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.genai = genai
        self.genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        self.genai_model = genai.GenerativeModel("gemini-pro")

        self.conn = sqlite3.connect(self.database)

    def fetch_table_schema(self, table_name):
        """Fetch the schema of a table from the database."""
        cursor = self.conn.cursor()
        cursor.execute(f"PRAGMA table_info({table_name})")
        schema = cursor.fetchall()
        if schema:
            return schema
        else:
            print(f"Table {table_name} does not exist or has no schema.")
            return None
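
    # Note: sqlite3's "PRAGMA table_info(<table>)" returns one row per column of the
    # form (cid, name, type, notnull, dflt_value, pk), e.g. (0, "id", "INTEGER", 0, None, 1).
    # The text2sql* methods below use only name (col[1]) and type (col[2]) to rebuild a
    # CREATE TABLE statement for the prompt; the example tuple is illustrative and is not
    # taken from sql_pdf.db.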
        
    def text2sql(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the local fine-tuned model.

        inp_prompt is supplied when the method is called from Gradio; otherwise
        the user is prompted on stdin. Requires the model/tokenizer loading in
        __init__ to be uncommented.
        """
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        ### Instruction: {user_prompt} ###
        Input: CREATE TABLE sql_pdf({table_columns});
        ### Response: (Return only the query, nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a prompt without newlines: ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")

        # Tokenize on the model's device, generate, and strip the echoed prompt.
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.to(next(self.model.parameters()).device)
        outputs = self.model.generate(input_ids=input_ids, max_new_tokens=200)
        response = self.tokenizer.batch_decode(
            outputs.detach().cpu().numpy(), skip_special_tokens=True
        )[0]
        return response[len(prompt):]

    def text2sql_chatgpt(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the OpenAI chat completions API."""
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        ### Instruction: {user_prompt} ###
        Input: CREATE TABLE sql_pdf({table_columns});
        ### Response: (Return only the generated query based on user_prompt, nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a prompt without newlines: ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")
        print(prompt)
        completion = self.chatgpt_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are an expert SQL developer; generate a SQL query and return it."},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    def text2sql_gemini(self, schema, user_prompt, inp_prompt=None):
        """Generate a SQL query with the Gemini (gemini-pro) model."""
        table_columns = ', '.join([f"{col[1]} {col[2]}" for col in schema])

        prompt = f"""Below are SQL table schemas paired with instructions that describe a task.
        Using valid SQLite, write a response that appropriately completes the request for the provided tables.
        ### Instruction: {user_prompt} ###
        Input: CREATE TABLE sql_pdf({table_columns});
        ### Response: (Return only the generated query based on user_prompt, nothing extra)"""

        if inp_prompt is not None:
            prompt = prompt.replace(user_prompt, inp_prompt + " ")
        else:
            inp_prompt = input("Press Enter for the default question, or type a prompt without newlines: ").strip()
            if inp_prompt:
                prompt = prompt.replace(user_prompt, inp_prompt + " ")
        print(prompt)
        completion = self.genai_model.generate_content(prompt)
        generated_query = completion.text
        # The model may wrap the query in extra text, so keep only the span from the
        # first SELECT to the closing semicolon when both are present.
        start_index = generated_query.find("SELECT")
        end_index = generated_query.find(";", start_index) + 1
        if start_index != -1 and end_index != 0:
            return generated_query[start_index:end_index]
        else:
            return generated_query

    def execute_query(self, query):
        """Execute the query against the database and return (rows, column names)."""
        print(query)
        cur = self.conn.cursor()
        cur.execute(query)
        col = [header[0] for header in cur.description]
        # Print a simple header and divider, then each fetched row.
        dash = "-" * sum(len(col_name) + 4 for col_name in col)
        print(tuple(col))
        print(dash)
        rows = []
        for member in cur:
            rows.append(member)
            print(member)
        cur.close()
        self.conn.commit()
        return rows, col
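
    # Illustrative usage (the column name "address" is an assumption, not taken from sql_pdf.db):
    #   rows, cols = sql_model.execute_query("SELECT address FROM sql_pdf LIMIT 5")
    # prints the query, a header of column names, and each row, then returns them.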

if __name__ == "__main__":
    model_dir = "multi_table_demo/checkpoint-2600"
    database = r"sql_pdf.db"
    sql_model = SQLPromptModel(model_dir, database)
    user_prompt = "Give complete details of properties in India"
    try:
        while True:
            table_schema = sql_model.fetch_table_schema("sql_pdf")
            if not table_schema:
                break
            # query = sql_model.text2sql(table_schema, user_prompt)
            # query = sql_model.text2sql_chatgpt(table_schema, user_prompt)
            query = sql_model.text2sql_gemini(table_schema, user_prompt)
            print(query)
            sql_model.execute_query(query)
    except KeyboardInterrupt:
        pass
    finally:
        sql_model.conn.close()