Update app.py
Browse files
app.py
CHANGED
@@ -47,9 +47,9 @@ df = pd.read_sql_query(query, conn)
|
|
47 |
conn.close()
|
48 |
'''
|
49 |
|
50 |
-
|
51 |
# Create a sample DataFrame with 3,000 records and 20 columns
|
52 |
-
num_records =
|
53 |
num_columns = 20
|
54 |
|
55 |
data = {
|
@@ -64,7 +64,7 @@ data["year"] = [random.choice(years) for _ in range(num_records)]
|
|
64 |
data["city"] = [random.choice(cities) for _ in range(num_records)]
|
65 |
|
66 |
table = pd.DataFrame(data)
|
67 |
-
|
68 |
#table = pd.read_csv(csv_file.name, delimiter=",")
|
69 |
#table.fillna(0, inplace=True)
|
70 |
#table = table.astype(str)
|
@@ -73,7 +73,7 @@ data = {
|
|
73 |
"year": [1896, 1900, 1904, 2004, 2008, 2012],
|
74 |
"city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
|
75 |
}
|
76 |
-
table = pd.DataFrame.from_dict(data)
|
77 |
|
78 |
|
79 |
# Load the chatbot model
|
@@ -130,7 +130,48 @@ def chat(input, history=[]):
|
|
130 |
def sqlquery(input): #, history=[]):
|
131 |
|
132 |
global conversation_history
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
inputs = [input]
|
135 |
sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
|
136 |
sql_outputs = sql_model.generate(**sql_encoding)
|
@@ -139,7 +180,7 @@ def sqlquery(input): #, history=[]):
|
|
139 |
#history.append((input, sql_response))
|
140 |
conversation_history.append(("User", input))
|
141 |
conversation_history.append(("Bot", sql_response))
|
142 |
-
|
143 |
# Build conversation string
|
144 |
#conversation = "\n".join([f"User: {user_msg}\nBot: {resp_msg}" for user_msg, resp_msg in conversation_history])
|
145 |
conversation = "\n".join([f"{sender}: {msg}" for sender, msg in conversation_history])
|
|
|
47 |
conn.close()
|
48 |
'''
|
49 |
|
50 |
+
|
51 |
# Create a sample DataFrame with 3,000 records and 20 columns
|
52 |
+
num_records = 3000
|
53 |
num_columns = 20
|
54 |
|
55 |
data = {
|
|
|
64 |
data["city"] = [random.choice(cities) for _ in range(num_records)]
|
65 |
|
66 |
table = pd.DataFrame(data)
|
67 |
+
|
68 |
#table = pd.read_csv(csv_file.name, delimiter=",")
|
69 |
#table.fillna(0, inplace=True)
|
70 |
#table = table.astype(str)
|
|
|
73 |
"year": [1896, 1900, 1904, 2004, 2008, 2012],
|
74 |
"city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
|
75 |
}
|
76 |
+
#table = pd.DataFrame.from_dict(data)
|
77 |
|
78 |
|
79 |
# Load the chatbot model
|
|
|
130 |
def sqlquery(input): #, history=[]):
|
131 |
|
132 |
global conversation_history
|
133 |
+
|
134 |
+
#======================================================================
|
135 |
+
batch_size = 10 # Number of records in each batch
|
136 |
+
num_records = 3000 # Total number of records in the dataset
|
137 |
+
for start_idx in range(0, num_records, batch_size):
|
138 |
+
end_idx = min(start_idx + batch_size, num_records)
|
139 |
+
|
140 |
+
# Get a batch of records
|
141 |
+
batch_data = table[start_idx:end_idx]
|
142 |
+
|
143 |
+
batch_responses = []
|
144 |
+
|
145 |
+
for idx, record in enumerate(batch_data):
|
146 |
+
# Maintain conversation context by appending history
|
147 |
+
if conversation_history:
|
148 |
+
history = "\n".join(conversation_history)
|
149 |
+
input_text = history + "\nUser: " + record["question"]
|
150 |
+
else:
|
151 |
+
input_text = "User: " + record["question"]
|
152 |
+
|
153 |
+
# Tokenize the input text
|
154 |
+
tokenized_input = sql_tokenizer.encode(input_text, return_tensors="pt")
|
155 |
+
|
156 |
+
# Perform inference
|
157 |
+
with torch.no_grad():
|
158 |
+
output = sql_model.generate(
|
159 |
+
input_ids=tokenized_input,
|
160 |
+
max_length=1024,
|
161 |
+
pad_token_id=sql_tokenizer.eos_token_id,
|
162 |
+
)
|
163 |
+
|
164 |
+
# Decode the output and process the response
|
165 |
+
response = sql_tokenizer.decode(output[0], skip_special_tokens=True)
|
166 |
+
batch_responses.append(response)
|
167 |
+
|
168 |
+
# Update conversation history
|
169 |
+
conversation_history.append("User: " + record["question"])
|
170 |
+
conversation_history.append("Bot: " + response)
|
171 |
+
|
172 |
+
|
173 |
+
# ==========================================================================
|
174 |
+
'''
|
175 |
inputs = [input]
|
176 |
sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
|
177 |
sql_outputs = sql_model.generate(**sql_encoding)
|
|
|
180 |
#history.append((input, sql_response))
|
181 |
conversation_history.append(("User", input))
|
182 |
conversation_history.append(("Bot", sql_response))
|
183 |
+
'''
|
184 |
# Build conversation string
|
185 |
#conversation = "\n".join([f"User: {user_msg}\nBot: {resp_msg}" for user_msg, resp_msg in conversation_history])
|
186 |
conversation = "\n".join([f"{sender}: {msg}" for sender, msg in conversation_history])
|