Update app.py
app.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 import os
+import time  # Added for timing logs
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -45,6 +46,14 @@ response_cache = {
         "The best places to open a brokerage account include Vanguard, Fidelity, Charles Schwab, and Robinhood. "
         "They offer low fees, no minimums, and user-friendly platforms for beginners."
     ),
+    "what is dollar-cost averaging?": (
+        "Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs, "
+        "reducing risk by spreading purchases over time."
+    ),
+    "how much should i invest?": (
+        "Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
+        "in ETFs like VOO using dollar-cost averaging. Consult a financial planner."
+    ),
 }

 # Load persistent cache
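The new cached answer claims dollar-cost averaging reduces risk by spreading purchases over time. As an illustrative sketch of the arithmetic behind that claim (hypothetical prices, not part of the commit), a fixed monthly dollar amount buys more shares when prices dip, so the average cost per share comes out at or below the average price:

    import statistics

    budget = 100.0                    # fixed monthly contribution (hypothetical)
    prices = [400.0, 380.0, 420.0]    # hypothetical monthly ETF share prices
    shares = [budget / p for p in prices]          # dips buy more shares
    avg_cost = budget * len(prices) / sum(shares)  # total spent / total shares
    print(f"avg cost/share: {avg_cost:.2f} vs avg price: {statistics.mean(prices):.2f}")
    # -> avg cost/share: 399.33 vs avg price: 400.00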
@@ -73,13 +82,13 @@ except Exception as e:
     logger.error(f"Error loading model/tokenizer: {e}")
     raise RuntimeError(f"Failed to load model: {str(e)}")

-# Updated prompt prefix with better instructions examples
+# Updated prompt prefix with better instructions and examples
 prompt_prefix = (
     "You are FinChat, a financial advisor. Always provide clear, step-by-step answers to the user's exact question. "
     "Avoid vague or unrelated topics. Use a numbered list format where appropriate and explain each step.\n\n"
     "Example 1:\n"
     "Q: How can I start investing with $100 a month?\n"
-    "A: Here’s a step-by-step guide:\n"
+    "A: Here’s a step-by-step guide:\n"
     "1. Open a brokerage account with a platform like Fidelity or Robinhood. They offer low fees and no minimums.\n"
     "2. Deposit your $100 monthly. You can set up automatic transfers.\n"
     "3. Choose a low-cost ETF like VOO, which tracks the S&P 500.\n"
@@ -97,10 +106,12 @@ def get_closest_cache_key(message, cache_keys, threshold=0.7):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None

-# Define chat function with
+# Define chat function with optimized generation parameters
 def chat_with_model(user_input, history=None):
     try:
+        start_time = time.time()  # Start timing
         logger.info(f"Processing user input: {user_input}")
+
         cache_key = user_input.lower().strip()
         cache_keys = list(response_cache.keys())
         closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
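For readers skimming the cache path above: difflib is part of the Python standard library, and get_close_matches ranks candidates by SequenceMatcher similarity ratio. A self-contained sketch of the lookup (the cache keys here are hypothetical stand-ins):

    import difflib

    cache_keys = ["what is dollar-cost averaging?", "how much should i invest?"]

    def get_closest_cache_key(message, cache_keys, threshold=0.7):
        # Returns the single best match at or above the similarity cutoff, else None.
        matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
        return matches[0] if matches else None

    print(get_closest_cache_key("what is dollar cost averaging", cache_keys))
    # -> "what is dollar-cost averaging?" despite the missing hyphen and "?"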
@@ -111,6 +122,8 @@ def chat_with_model(user_input, history=None):
             history = history or []
             history.append({"role": "user", "content": user_input})
             history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
             return response, history

         if len(user_input.strip()) <= 5:
@@ -120,22 +133,26 @@ def chat_with_model(user_input, history=None):
             history = history or []
             history.append({"role": "user", "content": user_input})
             history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
             return response, history

         full_prompt = prompt_prefix + user_input + "\nA:"
         inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

         with torch.inference_mode():
+            gen_start_time = time.time()  # Start generation timing
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=
+                max_new_tokens=75,  # Reduced for faster generation
                 min_length=20,
-                do_sample=
-                temperature=0.5,  # Lowered for more focused responses
-                top_p=0.9,
+                do_sample=False,  # Use greedy decoding for speed
                 repetition_penalty=1.2,
                 pad_token_id=tokenizer.eos_token_id
             )
+            gen_end_time = time.time()
+            logger.info(f"Generation time: {gen_end_time - gen_start_time:.2f} seconds")
+
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
         logger.info(f"Chatbot response: {response}")
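On the generation change above: with do_sample=False, transformers' generate performs greedy decoding, so the sampling knobs temperature and top_p no longer apply, which is why the commit drops them. A minimal sketch of the resulting call pattern, assuming the app's distilgpt2 checkpoint:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    inputs = tokenizer("Q: How can I start investing?\nA:", return_tensors="pt")
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=75,                    # bounded output length keeps latency low
            do_sample=False,                      # greedy: always take the top token
            repetition_penalty=1.2,               # damps loops common in small models
            pad_token_id=tokenizer.eos_token_id,  # distilgpt2 has no pad token by default
        )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))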
@@ -147,6 +164,8 @@ def chat_with_model(user_input, history=None):
         history.append({"role": "user", "content": user_input})
         history.append({"role": "assistant", "content": response})
         torch.cuda.empty_cache()
+        end_time = time.time()
+        logger.info(f"Total response time: {end_time - start_time:.2f} seconds")
         return response, history
     except Exception as e:
         logger.error(f"Error generating response: {e}")
@@ -157,34 +176,7 @@ def chat_with_model(user_input, history=None):
         history.append({"role": "assistant", "content": response})
         return response, history

-#
-def log_feedback_up(history):
-    if history:
-        last_user = history[-2]['content']
-        last_assistant = history[-1]['content']
-        feedback = {"question": last_user, "response": last_assistant, "feedback": "up"}
-        try:
-            with open("feedback.json", "a") as f:
-                json.dump(feedback, f)
-                f.write("\n")
-            logger.info("Logged positive feedback")
-        except Exception as e:
-            logger.warning(f"Failed to log feedback: {e}")
-
-def log_feedback_down(history):
-    if history:
-        last_user = history[-2]['content']
-        last_assistant = history[-1]['content']
-        feedback = {"question": last_user, "response": last_assistant, "feedback": "down"}
-        try:
-            with open("feedback.json", "a") as f:
-                json.dump(feedback, f)
-                f.write("\n")
-            logger.info("Logged negative feedback")
-        except Exception as e:
-            logger.warning(f"Failed to log feedback: {e}")
-
-# Create Gradio interface with feedback buttons
+# Create Gradio interface
 with gr.Blocks(
     title="FinChat: An LLM based on distilgpt2 model",
     css=".feedback {display: flex; gap: 10px; justify-content: center; margin-top: 10px;}"
@@ -201,12 +193,6 @@ with gr.Blocks(
     submit = gr.Button("Send")
     clear = gr.Button("Clear")

-    # Feedback section
-    gr.Markdown("**Was this helpful?**")
-    with gr.Row(elem_classes="feedback"):
-        thumbs_up = gr.Button("👍")
-        thumbs_down = gr.Button("👎")
-
     def submit_message(user_input, history):
         response, updated_history = chat_with_model(user_input, history)
         return "", updated_history  # Clear input, update chatbot
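The submit_message handler above follows the standard gr.Blocks contract: an event callback returns one value per declared output. A stripped-down sketch of that wiring, assuming a recent Gradio 4.x (the echo reply is a stand-in for chat_with_model):

    import gradio as gr

    def submit_message(user_input, history):
        history = (history or []) + [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": f"Echo: {user_input}"},  # stand-in reply
        ]
        return "", history  # first output clears the textbox, second refreshes the chat

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")  # expects role/content dicts, as in app.py
        msg = gr.Textbox(label="Your question")
        msg.submit(submit_message, inputs=[msg, chatbot], outputs=[msg, chatbot])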
@@ -220,15 +206,12 @@ with gr.Blocks(
         fn=lambda: ("", []),  # Clear input and chatbot
         outputs=[msg, chatbot]
     )
-
-    thumbs_up.click(fn=log_feedback_up, inputs=[chatbot], outputs=None)
-    thumbs_down.click(fn=log_feedback_down, inputs=[chatbot], outputs=None)

 # Launch interface (conditional for Spaces)
 if __name__ == "__main__" and not os.getenv("HF_SPACE"):
     logger.info("Launching Gradio interface locally")
     try:
-
+        interface.launch(share=False, debug=True)
     except Exception as e:
         logger.error(f"Error launching interface: {e}")
         raise
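One note on the launch guard at the end of the diff: Hugging Face Spaces imports and serves app.py itself, so the script only calls launch() when run directly outside Spaces. Whether the hosted environment actually sets an HF_SPACE variable is an assumption this code makes, not something the platform documents. A minimal sketch of the pattern:

    import os
    import gradio as gr

    demo = gr.Interface(fn=str.upper, inputs="text", outputs="text")

    # Assumption: HF_SPACE is set in the hosted environment and unset locally.
    if __name__ == "__main__" and not os.getenv("HF_SPACE"):
        demo.launch(share=False, debug=True)  # local run only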