File size: 7,893 Bytes
2714773
 
 
 
5782838
9d6743e
2714773
 
 
338e482
d7a34dd
2714773
d7a34dd
4050efc
5eca98b
9d6743e
2714773
 
ca38751
2714773
 
 
4050efc
2714773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8233187
4050efc
115834e
8233187
01b0bb9
 
 
 
 
 
 
 
 
b84957a
7931b6e
 
01b0bb9
b84957a
115834e
10b34aa
a582020
10b34aa
01b0bb9
2714773
 
 
 
115834e
2714773
 
 
 
 
 
 
aa595f0
 
4da7ef5
2714773
 
 
 
 
4c1270f
 
 
4da7ef5
2714773
29d45a4
2714773
978fd4d
2714773
 
 
6651d18
2714773
 
 
 
 
 
 
 
 
5758bb4
2714773
 
 
 
 
 
 
 
0f46be8
2714773
 
6931a12
2714773
5758bb4
2714773
 
fd2320e
ccddcad
29d45a4
8233187
 
115834e
8233187
 
 
 
 
 
 
 
a46b806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b2f839
 
8233187
115834e
10b34aa
 
0bb8f5d
 
2714773
29d45a4
 
 
115834e
d266202
fd2320e
 
20a1bcd
fd2320e
4050efc
fd2320e
d266202
 
ccddcad
bbb2bb6
ccddcad
 
 
 
d266202
2714773
 
978fd4d
5758bb4
 
 
 
 
 
2714773
 
f216d8b
2714773
 
 
82f4770
29d45a4
 
 
 
9c35ba8
 
2714773
 
 
f216d8b
4fc0d85
da85880
ccddcad
 
 
 
 
4fc0d85
2714773
 
 
01b0bb9
 
2714773
01b0bb9
2714773
 
 
d266202
4050efc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr

import numpy as np
import time
import os
import random

#import pyodbc

'''
import pkg_resources

# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}

# Print the list of packages
for package, version in installed_packages.items():
    print(f"{package}=={version}")
'''

'''
# Replace the connection parameters with your SQL Server information
server = 'your_server'
database = 'your_database'
username = 'your_username'
password = 'your_password'
driver = 'SQL Server'  # This depends on the ODBC driver installed on your system

# Create the connection string
connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'

# Connect to the SQL Server
conn = pyodbc.connect(connection_string)

#============================================================================
# Replace "your_query" with your SQL query to fetch data from the database
query = 'SELECT * FROM your_table_name'

# Use pandas to read data from the SQL Server and store it in a DataFrame
df = pd.read_sql_query(query, conn)

# Close the SQL connection
conn.close()
'''


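# Disabled experiment (kept inside the string literal below): builds a synthetic
# DataFrame with 3,000 records and 20 random integer columns, plus random
# "year" and "city" columns.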
'''
num_records = 3000
num_columns = 20

data = {
    f"column_{i}": np.random.randint(0, 100, num_records) for i in range(num_columns)
}

# Randomize the year and city columns
years = list(range(2000, 2023))  # Range of years
cities = ["New York", "Los Angeles", "Chicago", "Houston", "Miami"]  # List of cities

data["year"] = [random.choice(years) for _ in range(num_records)]
data["city"] = [random.choice(cities) for _ in range(num_records)]

table = pd.DataFrame(data)
'''
#table = pd.read_csv(csv_file.name, delimiter=",")
#table.fillna(0, inplace=True)
#table = table.astype(str)

# Small demo table that the SQL tab answers questions about:
# one row per (year, city) pair.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)


# Free-form chat model: chat() feeds it the running token history and decodes
# its continuation.
chatbot_model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Table question-answering model: sqlquery() encodes `table` together with the
# user's question and decodes the generated answer.
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Running transcript of (sender, message) pairs for the SQL tab. It lives at
# module level, so every call to sqlquery() appends to the same list.
conversation_history = []

def chat(input, history=None):
    """Generate the chatbot's reply to ``input`` given the conversation so far.

    Args:
        input: The user's new message as plain text.
        history: Token ids of the conversation so far, in the nested-list form
            returned by a previous call (``model.generate(...).tolist()``).
            ``None`` or an empty list starts a new conversation.

    Returns:
        A ``(response, history)`` tuple. ``response`` is a list of
        ``(user_message, bot_message)`` string pairs covering the whole
        conversation; ``history`` is the updated nested list of token ids to
        pass back in on the next call.
    """
    # Tokenize the new user turn, terminated by the end-of-sequence token that
    # separates dialogue turns.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # Append the new turn to the previous token history. On the first turn
    # there is nothing to prepend, so the new ids are used as they are.
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate the continuation. NOTE: max_length bounds the whole sequence
    # (history plus reply), not just the newly generated part.
    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()

    # Decode the full conversation and split it back into turns. Turns
    # alternate user/bot, so consecutive turns are paired up; the empty piece
    # left after the final separator is dropped by the pairing.
    turns = tokenizer.decode(history[0]).split(tokenizer.eos_token)
    response = list(zip(turns[0::2], turns[1::2]))

    return response, history


def sqlquery(input):
    """Answer a question about the module-level ``table`` and return the transcript.

    The question is encoded together with ``table`` by ``sql_tokenizer``,
    answered by ``sql_model``, and both the question and the answer are
    appended to the module-level ``conversation_history``.

    Args:
        input: The user's question as plain text.

    Returns:
        The whole transcript so far as one string, with one
        ``"<sender>: <message>"`` line per entry.
    """
    sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
    sql_outputs = sql_model.generate(**sql_encoding)

    # batch_decode returns one string per generated sequence; join them so the
    # transcript shows plain text rather than the repr of a Python list.
    answers = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
    sql_response = ", ".join(answer.strip() for answer in answers)

    # The list is only mutated, never rebound, so no `global` statement is
    # needed. NOTE(review): because it is module-level, the transcript is
    # shared by every caller of this function.
    conversation_history.append(("User", input))
    conversation_history.append(("Bot", sql_response))

    return "\n".join(f"{sender}: {msg}" for sender, msg in conversation_history)

# Chatbot tab. The "state" input/output pair carries chat()'s `history` value
# from one call to the next; the "chatbot" output shows the message pairs.
chat_interface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    css=".footer {display:none !important}",
)


# SQL chat tab: a single text box in, the accumulated transcript returned by
# sqlquery() out.
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    css=".footer {display:none !important}",
    # Textbox has no `prompt` parameter; the caption shown next to the input
    # box is set with `label`.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)

'''
iface = gr.Interface(sqlquery, "text", "html", css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
""", allow_screenshot=False, allow_flagging=False)
'''

# Bundle both demos into one app; tab names are listed in the same order as
# the interfaces they label.
combine_interface = gr.TabbedInterface(
    interface_list=[sql_interface, chat_interface],
    tab_names=["SQL Chat", "Chatbot"],
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    combine_interface.launch()