AnilNiraula committed (verified)
Commit f9c52da · Parent: 74f4aad

Update app.py

Files changed (1):
  1. app.py (+28 -45)

app.py CHANGED
@@ -1,5 +1,6 @@
 import logging
 import os
+import time  # Added for timing logs
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -45,6 +46,14 @@ response_cache = {
         "The best places to open a brokerage account include Vanguard, Fidelity, Charles Schwab, and Robinhood. "
         "They offer low fees, no minimums, and user-friendly platforms for beginners."
     ),
+    "what is dollar-cost averaging?": (
+        "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, "
+        "reducing risk by spreading purchases over time."
+    ),
+    "how much should i invest?": (
+        "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
+        "in ETFs like VOO using dollar-cost averaging. Consult a financial planner."
+    ),
 }
 
 # Load persistent cache
@@ -73,13 +82,13 @@ except Exception as e:
     logger.error(f"Error loading model/tokenizer: {e}")
     raise RuntimeError(f"Failed to load model: {str(e)}")
 
-# Updated prompt prefix with better instructions examples
+# Updated prompt prefix with better instructions and examples
 prompt_prefix = (
     "You are FinChat, a financial advisor. Always provide clear, step-by-step answers to the user's exact question. "
     "Avoid vague or unrelated topics. Use a numbered list format where appropriate and explain each step.\n\n"
     "Example 1:\n"
     "Q: How can I start investing with $100 a month?\n"
-    "A: Here’s a step-by-step guide:\n"
+    "A: Here’s a point-by-point guide:\n"
     "1. Open a brokerage account with a platform like Fidelity or Robinhood. They offer low fees and no minimums.\n"
     "2. Deposit your $100 monthly. You can set up automatic transfers.\n"
     "3. Choose a low-cost ETF like VOO, which tracks the S&P 500.\n"
@@ -97,10 +106,12 @@ def get_closest_cache_key(message, cache_keys, threshold=0.7):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None
 
-# Define chat function with updated generation parameters
+# Define chat function with optimized generation parameters
 def chat_with_model(user_input, history=None):
     try:
+        start_time = time.time()  # Start timing
         logger.info(f"Processing user input: {user_input}")
+
         cache_key = user_input.lower().strip()
         cache_keys = list(response_cache.keys())
         closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
@@ -111,6 +122,8 @@ def chat_with_model(user_input, history=None):
             history = history or []
             history.append({"role": "user", "content": user_input})
             history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
             return response, history
 
         if len(user_input.strip()) <= 5:
@@ -120,22 +133,26 @@ def chat_with_model(user_input, history=None):
             history = history or []
             history.append({"role": "user", "content": user_input})
             history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
             return response, history
 
         full_prompt = prompt_prefix + user_input + "\nA:"
         inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
 
         with torch.inference_mode():
+            gen_start_time = time.time()  # Start generation timing
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=150,  # Increased for longer responses
+                max_new_tokens=75,  # Reduced for faster generation
                 min_length=20,
-                do_sample=True,
-                temperature=0.5,  # Lowered for more focused responses
-                top_p=0.9,
+                do_sample=False,  # Use greedy decoding for speed
                 repetition_penalty=1.2,
                 pad_token_id=tokenizer.eos_token_id
             )
+            gen_end_time = time.time()
+            logger.info(f"Generation time: {gen_end_time - gen_start_time:.2f} seconds")
+
             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
             response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
             logger.info(f"Chatbot response: {response}")
@@ -147,6 +164,8 @@ def chat_with_model(user_input, history=None):
         history.append({"role": "user", "content": user_input})
         history.append({"role": "assistant", "content": response})
         torch.cuda.empty_cache()
+        end_time = time.time()
+        logger.info(f"Total response time: {end_time - start_time:.2f} seconds")
         return response, history
     except Exception as e:
         logger.error(f"Error generating response: {e}")
@@ -157,34 +176,7 @@ def chat_with_model(user_input, history=None):
         history.append({"role": "assistant", "content": response})
         return response, history
 
-# Feedback logging functions
-def log_feedback_up(history):
-    if history:
-        last_user = history[-2]['content']
-        last_assistant = history[-1]['content']
-        feedback = {"question": last_user, "response": last_assistant, "feedback": "up"}
-        try:
-            with open("feedback.json", "a") as f:
-                json.dump(feedback, f)
-                f.write("\n")
-            logger.info("Logged positive feedback")
-        except Exception as e:
-            logger.warning(f"Failed to log feedback: {e}")
-
-def log_feedback_down(history):
-    if history:
-        last_user = history[-2]['content']
-        last_assistant = history[-1]['content']
-        feedback = {"question": last_user, "response": last_assistant, "feedback": "down"}
-        try:
-            with open("feedback.json", "a") as f:
-                json.dump(feedback, f)
-                f.write("\n")
-            logger.info("Logged negative feedback")
-        except Exception as e:
-            logger.warning(f"Failed to log feedback: {e}")
-
-# Create Gradio interface with feedback buttons
+# Create Gradio interface
 with gr.Blocks(
     title="FinChat: An LLM based on distilgpt2 model",
     css=".feedback {display: flex; gap: 10px; justify-content: center; margin-top: 10px;}"
@@ -201,12 +193,6 @@ with gr.Blocks(
     submit = gr.Button("Send")
     clear = gr.Button("Clear")
 
-    # Feedback section
-    gr.Markdown("**Was this helpful?**")
-    with gr.Row(elem_classes="feedback"):
-        thumbs_up = gr.Button("👍")
-        thumbs_down = gr.Button("👎")
-
    def submit_message(user_input, history):
        response, updated_history = chat_with_model(user_input, history)
        return "", updated_history  # Clear input, update chatbot
@@ -220,15 +206,12 @@ with gr.Blocks(
        fn=lambda: ("", []),  # Clear input and chatbot
        outputs=[msg, chatbot]
    )
-
-    thumbs_up.click(fn=log_feedback_up, inputs=[chatbot], outputs=None)
-    thumbs_down.click(fn=log_feedback_down, inputs=[chatbot], outputs=None)
 
 # Launch interface (conditional for Spaces)
 if __name__ == "__main__" and not os.getenv("HF_SPACE"):
     logger.info("Launching Gradio interface locally")
     try:
-        interface.launch(share=False, debug=True)
+        interface.launch(share=False, debug=True)
     except Exception as e:
         logger.error(f"Error launching interface: {e}")
         raise
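
The headline change here is the decoding setup: sampling (temperature=0.5, top_p=0.9) is dropped for greedy decoding, and max_new_tokens falls from 150 to 75. A minimal standalone sketch of that trade-off, assuming the distilgpt2 model named in the interface title; the prompt string and the side-by-side timing harness are illustrative, not part of app.py:

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative setup mirroring the Space: distilgpt2, default device.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Hypothetical prompt; app.py builds its prompt from prompt_prefix + user input.
inputs = tokenizer("Q: How can I start investing with $100 a month?\nA:", return_tensors="pt")

def timed_generate(**gen_kwargs):
    # Time one generate() call under the given decoding settings.
    start = time.time()
    with torch.inference_mode():
        model.generate(
            **inputs,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs,
        )
    return time.time() - start

# Old settings: sampling, up to 150 new tokens.
t_old = timed_generate(max_new_tokens=150, do_sample=True, temperature=0.5, top_p=0.9)
# New settings: greedy decoding, up to 75 new tokens.
t_new = timed_generate(max_new_tokens=75, do_sample=False)
print(f"sampling/150: {t_old:.2f}s | greedy/75: {t_new:.2f}s")

Halving max_new_tokens halves the worst-case number of forward passes, which is where most of the wall-clock time goes; do_sample=False also makes output deterministic, so the same question always yields the same answer.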
 
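The two new cache entries only pay off if near-miss phrasings reach them, which is handled by get_closest_cache_key (visible unchanged in the hunk context above, with its 0.7 similarity cutoff). A self-contained sketch of that lookup path, reusing the function body from app.py; the misspelled query is made up for illustration:

import difflib

# Trimmed copy of the response cache, with the two entries this commit adds.
response_cache = {
    "what is dollar-cost averaging?": (
        "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, "
        "reducing risk by spreading purchases over time."
    ),
    "how much should i invest?": (
        "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
        "in ETFs like VOO using dollar-cost averaging. Consult a financial planner."
    ),
}

def get_closest_cache_key(message, cache_keys, threshold=0.7):
    # Same logic as app.py: best fuzzy match above the similarity cutoff, else None.
    matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
    return matches[0] if matches else None

query = "whats dollar cost averaging"  # illustrative near-miss phrasing
key = get_closest_cache_key(query.lower().strip(), list(response_cache.keys()))
print(response_cache[key] if key else "cache miss -> falls through to model.generate()")

Cache hits return before tokenization and generation, which is why the commit logs "Response time" on those early-return paths and a separate "Generation time" around model.generate.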