teaevo committed on
Commit
8233187
1 Parent(s): a582020

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -6
app.py CHANGED
@@ -47,9 +47,9 @@ df = pd.read_sql_query(query, conn)
47
  conn.close()
48
  '''
49
 
50
- '''
51
  # Create a sample DataFrame with 3,000 records and 20 columns
52
- num_records = 100
53
  num_columns = 20
54
 
55
  data = {
@@ -64,7 +64,7 @@ data["year"] = [random.choice(years) for _ in range(num_records)]
64
  data["city"] = [random.choice(cities) for _ in range(num_records)]
65
 
66
  table = pd.DataFrame(data)
67
- '''
68
  #table = pd.read_csv(csv_file.name, delimiter=",")
69
  #table.fillna(0, inplace=True)
70
  #table = table.astype(str)
@@ -73,7 +73,7 @@ data = {
73
  "year": [1896, 1900, 1904, 2004, 2008, 2012],
74
  "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
75
  }
76
- table = pd.DataFrame.from_dict(data)
77
 
78
 
79
  # Load the chatbot model
@@ -130,7 +130,48 @@ def chat(input, history=[]):
130
  def sqlquery(input): #, history=[]):
131
 
132
  global conversation_history
133
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  inputs = [input]
135
  sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
136
  sql_outputs = sql_model.generate(**sql_encoding)
@@ -139,7 +180,7 @@ def sqlquery(input): #, history=[]):
139
  #history.append((input, sql_response))
140
  conversation_history.append(("User", input))
141
  conversation_history.append(("Bot", sql_response))
142
-
143
  # Build conversation string
144
  #conversation = "\n".join([f"User: {user_msg}\nBot: {resp_msg}" for user_msg, resp_msg in conversation_history])
145
  conversation = "\n".join([f"{sender}: {msg}" for sender, msg in conversation_history])
 
47
  conn.close()
48
  '''
49
 
50
+
51
  # Create a sample DataFrame with 3,000 records and 20 columns
52
+ num_records = 3000
53
  num_columns = 20
54
 
55
  data = {
 
64
  data["city"] = [random.choice(cities) for _ in range(num_records)]
65
 
66
  table = pd.DataFrame(data)
67
+
68
  #table = pd.read_csv(csv_file.name, delimiter=",")
69
  #table.fillna(0, inplace=True)
70
  #table = table.astype(str)
 
73
  "year": [1896, 1900, 1904, 2004, 2008, 2012],
74
  "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
75
  }
76
+ #table = pd.DataFrame.from_dict(data)
77
 
78
 
79
  # Load the chatbot model
 
130
  def sqlquery(input): #, history=[]):
131
 
132
  global conversation_history
133
+
134
+ #======================================================================
135
+ batch_size = 10 # Number of records in each batch
136
+ num_records = 3000 # Total number of records in the dataset
137
+ for start_idx in range(0, num_records, batch_size):
138
+ end_idx = min(start_idx + batch_size, num_records)
139
+
140
+ # Get a batch of records
141
+ batch_data = table[start_idx:end_idx]
142
+
143
+ batch_responses = []
144
+
145
+ for idx, record in enumerate(batch_data):
146
+ # Maintain conversation context by appending history
147
+ if conversation_history:
148
+ history = "\n".join(conversation_history)
149
+ input_text = history + "\nUser: " + record["question"]
150
+ else:
151
+ input_text = "User: " + record["question"]
152
+
153
+ # Tokenize the input text
154
+ tokenized_input = sql_tokenizer.encode(input_text, return_tensors="pt")
155
+
156
+ # Perform inference
157
+ with torch.no_grad():
158
+ output = sql_model.generate(
159
+ input_ids=tokenized_input,
160
+ max_length=1024,
161
+ pad_token_id=sql_tokenizer.eos_token_id,
162
+ )
163
+
164
+ # Decode the output and process the response
165
+ response = sql_tokenizer.decode(output[0], skip_special_tokens=True)
166
+ batch_responses.append(response)
167
+
168
+ # Update conversation history
169
+ conversation_history.append("User: " + record["question"])
170
+ conversation_history.append("Bot: " + response)
171
+
172
+
173
+ # ==========================================================================
174
+ '''
175
  inputs = [input]
176
  sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
177
  sql_outputs = sql_model.generate(**sql_encoding)
 
180
  #history.append((input, sql_response))
181
  conversation_history.append(("User", input))
182
  conversation_history.append(("Bot", sql_response))
183
+ '''
184
  # Build conversation string
185
  #conversation = "\n".join([f"User: {user_msg}\nBot: {resp_msg}" for user_msg, resp_msg in conversation_history])
186
  conversation = "\n".join([f"{sender}: {msg}" for sender, msg in conversation_history])