File size: 6,417 Bytes
57816fc
 
8a8e4f3
1d0f3f8
8a8e4f3
57816fc
1d0f3f8
a148b10
8a8e4f3
1d0f3f8
 
a148b10
 
 
 
1d0f3f8
 
 
 
 
 
 
 
 
a148b10
 
 
 
1d0f3f8
 
a148b10
 
1d0f3f8
a148b10
 
57816fc
1d0f3f8
57816fc
 
1d0f3f8
57816fc
a148b10
 
57816fc
 
 
 
 
1d0f3f8
 
57816fc
a148b10
 
1d0f3f8
 
 
 
57816fc
b3eb06a
8a8e4f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d0f3f8
b3eb06a
57816fc
 
 
1d0f3f8
 
 
 
 
 
57816fc
 
1d0f3f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57816fc
 
1d0f3f8
 
 
57816fc
 
 
1d0f3f8
b3eb06a
1d0f3f8
b3eb06a
1d0f3f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import streamlit as st
import requests
import subprocess
import re
import sys

# Prompt layout expected by the DuckDB-NSQL model; filled in by generate_prompt().
PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
# {has_schema} is either "." or a clause saying a schema is provided.
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}"""  # noqa: E501
# Markdown shown when the generated SQL fails validation; formatted with
# sql_query and error_msg.
ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[ Try rephrasing the question/instruction. And if the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
STOP_TOKENS = ["###", ";", "--", "```"]

# One CREATE TABLE statement; group 1 is the column list between the opening
# parenthesis and the closing ");". DOTALL lets the list span several lines.
_CREATE_TABLE_RE = re.compile(r"CREATE TABLE [^(]+\((.*?)\);", flags=re.DOTALL)
# A "<word> <word>" pair inside a column list, e.g. "trip_miles DOUBLE".
_COLUMN_TYPE_RE = re.compile(r"(\w+) (\w+)")


def _lowercase_column_types(create_table_match):
    """Return the matched CREATE TABLE statement with each pair's second word lowercased."""
    offset = create_table_match.start()
    statement = create_table_match.group(0)
    body_start = create_table_match.start(1) - offset
    body_end = create_table_match.end(1) - offset
    # Rewrite matches in place: unlike a global str.replace, this cannot touch
    # substrings of other columns (e.g. "a INT" inside "data INTEGER") or text
    # outside the column list.
    body = _COLUMN_TYPE_RE.sub(
        lambda pair: f"{pair.group(1)} {pair.group(2).lower()}",
        create_table_match.group(1),
    )
    return statement[:body_start] + body + statement[body_end:]


def generate_prompt(question, schema):
    """Build the instruction prompt sent to the DuckDB-NSQL model.

    When *schema* is non-empty, the second word of every ``<word> <word>``
    pair inside each ``CREATE TABLE ... (...);`` column list is lowercased
    (e.g. ``trip_miles DOUBLE`` -> ``trip_miles double``), and the schema is
    embedded in the ``### Input`` section of the prompt.

    Args:
        question: Natural-language question the SQL should answer.
        schema: CREATE TABLE statements, or an empty string / ``None`` when
            no schema is available.

    Returns:
        The fully formatted prompt string.
    """
    schema_section = ""
    if schema:
        schema = _CREATE_TABLE_RE.sub(_lowercase_column_types, schema)
        schema_section = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format(  # noqa: E501
            schema=schema
        )
    prompt = PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(
            # Same truthiness test as above, so a None/empty schema never
            # claims that a schema was provided.
            has_schema="." if not schema else ", given a duckdb database schema."
        ),
        input=schema_section,
        question=question,
    )
    return prompt


def generate_sql(question, schema, timeout=60):
    """Ask the hosted DuckDB-NSQL model to write a SQL query for *question*.

    Args:
        question: Natural-language request from the user.
        schema: Optional CREATE TABLE statements forwarded to generate_prompt().
        timeout: Seconds ``requests`` waits for the endpoint (connect and
            read) before raising a timeout error.

    Returns:
        The completion text of the first returned choice.

    Raises:
        requests.HTTPError: the endpoint answered with a 4xx/5xx status.
        requests.RequestException: connection failure or timeout.
    """
    prompt = generate_prompt(question, schema)

    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        # NOTE(review): module-level STOP_TOKENS is not sent here; confirm
        # whether the endpoint should stop on those instead of "<s>".
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # The `with` blocks close both the response and the session's connection
    # pool once this single request is done.
    with requests.Session() as session:
        with session.post(url, json=body, headers=headers, timeout=timeout) as resp:
            # Fail with a clear HTTP error instead of a KeyError on "choices".
            resp.raise_for_status()
            sql_query = resp.json()["choices"][0]["text"]

    return sql_query


def validate_sql(query, schema, timeout=0.5):
    """Check *query* against *schema* by running ``./validate_sql.py`` in a child process.

    The helper is run as ``<current python> ./validate_sql.py <query> <schema>``,
    resolved against the current working directory. Any stderr output from it
    is treated as a parser/binder error.

    Args:
        query: SQL text to validate.
        schema: Schema text forwarded to the helper script.
        timeout: Seconds to wait for the helper before giving up.

    Returns:
        ``(False, <stderr text>)`` when the helper wrote to stderr; otherwise
        ``(True, "")``, which includes the case where it did not finish
        within *timeout*.
    """
    try:
        # subprocess.run kills *and reaps* the child on timeout, so no zombie
        # process or open pipes are left behind.
        completed = subprocess.run(
            [sys.executable, './validate_sql.py', query, schema],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # Timeout reached, so parsing and binding was very likely successful.
        return True, ""
    if completed.stderr:
        # errors="replace": a stray non-UTF-8 byte must not crash the app.
        return False, completed.stderr.decode('utf8', errors='replace')
    return True, ""


st.title("DuckDB-NSQL-7B Demo")

# Collapsible panel that explains how to dump an existing database's schema
# with the duckdb CLI so it can be pasted into the schema editor.
expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, run this command in your terminal to get your current schema:"
)
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n    /g' | sed 's/, /,\\n    /g' | sed 's/);/\\n);\\n/g'\n```""",
)

# Default schema pre-filled in the editable "Current schema" text area inside
# the expander; users can replace it with their own CREATE TABLE statements.
default_schema = """CREATE TABLE rideshare(
    hvfhs_license_num VARCHAR,
    dispatching_base_num VARCHAR,
    originating_base_num VARCHAR,
    request_datetime TIMESTAMP,
    on_scene_datetime TIMESTAMP,
    pickup_datetime TIMESTAMP,
    dropoff_datetime TIMESTAMP,
    PULocationID BIGINT,
    DOLocationID BIGINT,
    trip_miles DOUBLE,
    trip_time BIGINT,
    base_passenger_fare DOUBLE,
    tolls DOUBLE,
    bcf DOUBLE,
    sales_tax DOUBLE,
    congestion_surcharge DOUBLE,
    airport_fee DOUBLE,
    tips DOUBLE,
    driver_pay DOUBLE,
    shared_request_flag VARCHAR,
    shared_match_flag VARCHAR,
    access_a_ride_flag VARCHAR,
    wav_request_flag VARCHAR,
    wav_match_flag VARCHAR
);

CREATE TABLE service_requests(
    unique_key BIGINT,
    created_date TIMESTAMP,
    closed_date TIMESTAMP,
    agency VARCHAR,
    agency_name VARCHAR,
    complaint_type VARCHAR,
    descriptor VARCHAR,
    location_type VARCHAR,
    incident_zip VARCHAR,
    incident_address VARCHAR,
    street_name VARCHAR,
    cross_street_1 VARCHAR,
    cross_street_2 VARCHAR,
    intersection_street_1 VARCHAR,
    intersection_street_2 VARCHAR,
    address_type VARCHAR,
    city VARCHAR,
    landmark VARCHAR,
    facility_type VARCHAR,
    status VARCHAR,
    due_date TIMESTAMP,
    resolution_description VARCHAR,
    resolution_action_updated_date TIMESTAMP,
    community_board VARCHAR,
    bbl VARCHAR,
    borough VARCHAR,
    x_coordinate_state_plane VARCHAR,
    y_coordinate_state_plane VARCHAR,
    open_data_channel_type VARCHAR,
    park_facility_name VARCHAR,
    park_borough VARCHAR,
    vehicle_type VARCHAR,
    taxi_company_borough VARCHAR,
    taxi_pick_up_location VARCHAR,
    bridge_highway_name VARCHAR,
    bridge_highway_direction VARCHAR,
    road_ramp VARCHAR,
    bridge_highway_segment VARCHAR,
    latitude DOUBLE,
    longitude DOUBLE
);

CREATE TABLE taxi(
    VendorID BIGINT,
    tpep_pickup_datetime TIMESTAMP,
    tpep_dropoff_datetime TIMESTAMP,
    passenger_count DOUBLE,
    trip_distance DOUBLE,
    RatecodeID DOUBLE,
    store_and_fwd_flag VARCHAR,
    PULocationID BIGINT,
    DOLocationID BIGINT,
    payment_type BIGINT,
    fare_amount DOUBLE,
    extra DOUBLE,
    mta_tax DOUBLE,
    tip_amount DOUBLE,
    tolls_amount DOUBLE,
    improvement_surcharge DOUBLE,
    total_amount DOUBLE,
    congestion_surcharge DOUBLE,
    airport_fee DOUBLE,
    drivers VARCHAR[],
    speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
    other_violations JSON
);"""
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for text prompt
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)

if text_prompt:
    try:
        sql_query = generate_sql(text_prompt, schema)
    except (requests.RequestException, KeyError, IndexError, ValueError) as err:
        # Connection/HTTP errors, timeouts, and completion payloads that are
        # not JSON or lack choices[0]["text"] are reported on the page instead
        # of aborting the script run with a raw traceback.
        st.error(f"SQL generation request failed: {err}")
    else:
        # Parse/bind the generated query against the schema before showing it.
        valid, msg = validate_sql(sql_query, schema)
        if valid:
            st.markdown(f"""```sql\n{sql_query}\n```""")
        else:
            st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))