import json
import time
from difflib import ndiff

import psycopg2

"""
prepare the test samples
"""


def execute_sql(sql):
    """Execute a SQL query on the tpch10x database and return the number of result rows."""
    conn = psycopg2.connect(
        database="tpch10x", user="xxx", password="xxx", host="xxx", port="xxx"
    )

    cur = conn.cursor()
    cur.execute(sql)
    res = cur.fetchall()

    conn.commit()
    cur.close()
    conn.close()

    return len(res)
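
# NOTE: hypothetical helper, not part of the original script. `ndiff` is imported
# above but never used, and the comments below mention filtering queries by edit
# distance (< 10). This sketch shows one way such a filter could be written with
# difflib: it approximates the edit distance between two SQL strings as the number
# of differing characters (not true Levenshtein distance).
def edit_distance(sql_a, sql_b):
    diff = ndiff(sql_a, sql_b)
    return sum(1 for d in diff if d.startswith(("-", "+")))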


# Load the test queries (a list of {"sql": ...} records).
with open("text2res_single_table.json", "r") as f:
    data = json.load(f)

# Select the test SQL statements.
# The original intent was to keep only diverse statements (edit distance of less
# than 10; see the edit_distance() sketch above), but this loop does not apply
# that filter: it simply executes every query and records its result count and
# execution time.
selected_sql = []
for item in data:
    if "sql" not in item:
        continue
    sql = item["sql"]
    print("==========sql", sql)

    # Time the query and record how many rows it returns.
    start_time = time.time()
    res_cnt = execute_sql(sql)
    elapsed_time = time.time() - start_time

    print(res_cnt, elapsed_time)

    selected_sql.append(
        {"sql": sql, "res_cnt": res_cnt, "execution_time": elapsed_time}
    )


# Write the collected records to a JSON file
with open("text2res_single_table2.json", "w") as f:
    json.dump(selected_sql, f)


"""
add text descriptions for queries
"""
if __name__ == "__main__":
    llm = LLM()  # replace the placeholder stub above with your own LLM wrapper

    with open("./tpch10x/text2res_single_table2.json", "r") as json_file:
        json_data = json.load(json_file)

    new_json_data = []
    for i, item in enumerate(json_data):
        sql = item["sql"]
        print("========= ", i, sql)
        prompt = (
            "Please convert the following sql query into one natural language sentence: \n"
            + sql
            + "\n Note: 1) Do not mention any other information other than the natural language sentence; 2) Must use the original table and column names in the sql query."
        )
        text = llm(prompt)
        item["text"] = text
        new_json_data.append(item)

    with open("text2res_single_table3.json", "w") as f:
        json.dump(new_json_data, f)


"""
calculate total execution time
"""

with open("text2res_origin.json", "r") as json_file:
    json_data = json.load(json_file)

total_time = 0.0

for item in json_data:
    print(item["execution_time"])
    total_time += float(item["execution_time"])

print(total_time)