Update app.py

app.py CHANGED
@@ -11,6 +11,9 @@ from huggingface_hub import hf_hub_download, login
 import logging
 import pandas as pd
 import torch
+import yfinance as yf
+from datetime import datetime, timedelta
+from math import sqrt
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -27,14 +30,14 @@ except ModuleNotFoundError:
 else:
     logger.info("Installing llama-cpp-python without additional flags.")
     subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
-from llama_cpp import Llama
+from llama_cpp import Llama
 
 # Install yfinance if not present (for CAGR calculations)
 try:
     import yfinance as yf
 except ModuleNotFoundError:
     subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
-import yfinance as yf
+import yfinance as yf
 
 # Import pandas for handling DataFrame column structures
 import pandas as pd
@@ -46,9 +49,9 @@ try:
     import io
 except ModuleNotFoundError:
     subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
-import matplotlib.pyplot as plt
-from PIL import Image
-import io
+import matplotlib.pyplot as plt
+from PIL import Image
+import io
 
 MAX_MAX_NEW_TOKENS = 512
 DEFAULT_MAX_NEW_TOKENS = 512
@@ -72,12 +75,12 @@ try:
     llm = Llama(
         model_path=model_path,
         n_ctx=1024,
-        n_batch=1024,
+        n_batch=1024,  # Increased for faster processing
         n_threads=multiprocessing.cpu_count(),
         n_gpu_layers=n_gpu_layers,
-        chat_format="chatml"
+        chat_format="chatml"  # Phi-2 uses ChatML format in llama.cpp
     )
-    logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
+    logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
     # Warm up the model for faster initial inference
     llm("Warm-up prompt", max_tokens=1, echo=False)
     logger.info("Model warm-up completed.")
@@ -90,323 +93,147 @@ atexit.register(llm.close)
 
 DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
 Always respond exclusively in English. Use bullet points for clarity.
+Do not substitute or alter stock symbols provided in the user's query. Always use the exact tickers mentioned.
 Example:
-User: average return for
-Assistant:
--
+User: average return for AAPL between 2010 and 2020
+Assistant:
+- AAPL CAGR (2010-2020): ~27.24%
 - Represents average annual return with compounding
-- Past performance not indicative of future results
-- Consult a financial advisor"""
+- Past performance is not indicative of future results."""
 
-#
-
-
-
-
-
-
-
+# Function to calculate CAGR using yfinance
+def calculate_cagr(ticker, start_date, end_date):
+    try:
+        data = yf.download(ticker, start=start_date, end=end_date)
+        if data.empty:
+            return None
+        start_price = data['Adj Close'].iloc[0]
+        end_price = data['Adj Close'].iloc[-1]
+        num_years = (data.index[-1] - data.index[0]).days / 365.25
+        cagr = (end_price / start_price) ** (1 / num_years) - 1
+        return cagr * 100  # Return as percentage
+    except Exception as e:
+        logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
+        return None
+
+# New function to calculate risk metrics using yfinance
+def calculate_risk_metrics(ticker, years=5):
+    try:
+        end_date = datetime.now().strftime('%Y-%m-%d')
+        start_date = (datetime.now() - timedelta(days=365 * years)).strftime('%Y-%m-%d')
+        data = yf.download(ticker, start=start_date, end=end_date)
+        if data.empty:
+            return None, None
+        returns = data['Adj Close'].pct_change().dropna()
+        volatility = returns.std() * sqrt(252) * 100  # Annualized volatility in percent
+        mean_return = returns.mean() * 252  # Annualized mean return
+        risk_free_rate = 0.02  # Assumed risk-free rate (e.g., 2%)
+        sharpe = (mean_return - risk_free_rate) / (volatility / 100)  # Sharpe ratio
+        return volatility, sharpe
+    except Exception as e:
+        logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
+        return None, None
 
+# Assuming the generate function handles the chat logic (extended to include risk comparison)
 def generate(
     message: str,
-
-    system_prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-    repetition_penalty: float = 1.2,
+    history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
 ) -> Iterator[str]:
-
-
-
-
-
-
-
-
-
-
-
-
-
--
-- Past performance not indicative of future results. Consult a financial advisor."""
-        logger.info("Quick response for 'what is cagr' generated.")
-        yield response
-        return
-
-    # Check for CAGR/average return queries (use re.search for flexible matching)
-    match = re.search(r'(?:average return|cagr) for ([\w\s,]+(?:and [\w\s,]+)?) between (\d{4}) and (\d{4})', lower_message)
-    if match:
-        tickers_str, start_year, end_year = match.groups()
-        tickers = [t.strip().upper() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
-
-        # Apply company-to-ticker mapping
-        for i in range(len(tickers)):
-            lower_ticker = tickers[i].lower()
-            if lower_ticker in COMPANY_TO_TICKER:
-                tickers[i] = COMPANY_TO_TICKER[lower_ticker]
-
-        responses = []
-        if int(end_year) <= int(start_year):
-            yield "The specified time period is invalid (end year must be after start year)."
+    if not system_prompt:
+        system_prompt = DEFAULT_SYSTEM_PROMPT
+
+    # Detect CAGR query
+    cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
+    if cagr_match:
+        ticker = cagr_match.group(1).upper()
+        start_year = cagr_match.group(2)
+        end_year = cagr_match.group(3)
+        start_date = f"{start_year}-01-01"
+        end_date = f"{end_year}-12-31"
+        cagr = calculate_cagr(ticker, start_date, end_date)
+        if cagr is not None:
+            yield f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
             return
-
-        for ticker in tickers:
-            try:
-                # Download data with adjusted close prices
-                data = yf.download(ticker, start=f"{start_year}-01-01", end=f"{end_year}-12-31", progress=False, auto_adjust=False)
-                # Handle potential MultiIndex columns in newer yfinance versions
-                if isinstance(data.columns, pd.MultiIndex):
-                    data.columns = data.columns.droplevel(1)
-                if not data.empty:
-                    # Check if 'Adj Close' column exists
-                    if 'Adj Close' not in data.columns:
-                        responses.append(f"- {ticker}: Error - Adjusted Close price data not available.")
-                        logger.error(f"No 'Adj Close' column for {ticker}.")
-                        continue
-                    # Ensure data is not MultiIndex for single ticker (already handled)
-                    initial = data['Adj Close'].iloc[0]
-                    final = data['Adj Close'].iloc[-1]
-                    start_date = data.index[0]
-                    end_date = data.index[-1]
-                    days = (end_date - start_date).days
-                    years = days / 365.25
-                    if years > 0 and pd.notna(initial) and pd.notna(final):
-                        cagr = ((final / initial) ** (1 / years) - 1) * 100
-                        responses.append(f"- {ticker}: ~{cagr:.2f}%")
-                    else:
-                        responses.append(f"- {ticker}: Invalid period or missing price data.")
-                else:
-                    responses.append(f"- {ticker}: No historical data available between {start_year} and {end_year}.")
-            except Exception as e:
-                logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
-                responses.append(f"- {ticker}: Error calculating CAGR - {str(e)}")
-
-        full_response = f"CAGR for the requested stocks from {start_year} to {end_year}:\n" + "\n".join(responses) + "\n- Represents average annual returns with compounding\n- Past performance not indicative of future results\n- Consult a financial advisor"
-        full_response = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', full_response).strip()  # Clean any trailing tokens
-
-        # Estimate token count to ensure response fits within max_new_tokens
-        response_tokens = len(llm.tokenize(full_response.encode("utf-8"), add_bos=False))
-        if response_tokens > max_new_tokens:
-            logger.warning(f"CAGR response tokens ({response_tokens}) exceed max_new_tokens ({max_new_tokens}). Truncating to first complete sentence.")
-            sentence_endings = ['.', '!', '?']
-            first_sentence_end = min([full_response.find(ending) + 1 for ending in sentence_endings if full_response.find(ending) != -1], default=len(full_response))
-            full_response = full_response[:first_sentence_end] if first_sentence_end > 0 else "Response truncated due to length; please increase Max New Tokens."
-
-        logger.info("CAGR response generated.")
-        yield full_response
-        return
-
-    # Build conversation messages (limit history to last 3 for speed)
-    conversation = [{"role": "system", "content": system_prompt}]
-    for msg in chat_history[-3:]:  # Reduced from 5 to 3 for faster processing
-        if msg["role"] == "user":
-            conversation.append({"role": "user", "content": msg["content"]})
-        elif msg["role"] == "assistant":
-            conversation.append({"role": "assistant", "content": msg["content"]})
-    conversation.append({"role": "user", "content": message})
-
-    # Approximate token length check and truncate if necessary
-    prompt_text = "\n".join(d["content"] for d in conversation)
-    input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
-    while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
-        logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit ({MAX_INPUT_TOKEN_LENGTH}). Truncating history.")
-        if len(conversation) > 2:  # Preserve system prompt and current user message
-            conversation.pop(1)  # Remove oldest user/assistant pair
-            prompt_text = "\n".join(d["content"] for d in conversation)
-            input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
         else:
-            yield "
+            yield "Unable to calculate CAGR for the specified period."
             return
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if "content" in delta and delta["content"] is not None:
-                # Clean the chunk by removing ChatML tokens or similar
-                cleaned_chunk = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', delta["content"])
-                if not cleaned_chunk:
-                    continue
-                sentence_buffer += cleaned_chunk
-                response += cleaned_chunk
-                # Approximate token count for the chunk
-                chunk_tokens = len(llm.tokenize(cleaned_chunk.encode("utf-8"), add_bos=False))
-                token_count += chunk_tokens
-                # Check for sentence boundary
-                if any(sentence_buffer.strip().endswith(ending) for ending in sentence_endings):
-                    yield response
-                    sentence_buffer = ""  # Clear buffer after yielding a complete sentence
-                # Removed early truncation to allow full token utilization
-
-            if chunk["choices"][0]["finish_reason"] is not None:
-                # Yield any remaining complete sentence in the buffer
-                if sentence_buffer.strip():
-                    last_sentence_end = max([sentence_buffer.rfind(ending) for ending in sentence_endings if sentence_buffer.rfind(ending) != -1], default=-1)
-                    if last_sentence_end != -1:
-                        response = response[:response.rfind(sentence_buffer) + last_sentence_end + 1]
-                        yield response
-                    else:
-                        yield response
-                else:
-                    yield response
-                break
-        logger.info("Response generation completed.")
-    except ValueError as ve:
-        if "exceed context window" in str(ve):
-            yield "Error: Prompt too long for context window. Please try a shorter query or clear history."
+    # Detect risk comparison query
+    risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
+    if risk_match:
+        ticker1 = risk_match.group(1).upper()
+        ticker2 = risk_match.group(2).upper()
+        vol1, sharpe1 = calculate_risk_metrics(ticker1)
+        vol2, sharpe2 = calculate_risk_metrics(ticker2)
+        if vol1 is None or vol2 is None:
+            yield "Unable to fetch risk metrics for one or both tickers."
+            return
+        if vol1 > vol2:
+            riskier = ticker1
+            less_risky = ticker2
+            higher_vol = vol1
+            lower_vol = vol2
+            riskier_sharpe = sharpe1
+            less_sharpe = sharpe2
        else:
-
-
-
-
-
-
-
-
-        return "", None
-    # Convert to DataFrame if needed
-    if not isinstance(df, pd.DataFrame):
-        df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
-    df = df.dropna(subset=["Ticker"])
-    portfolio = {}
-    for _, row in df.iterrows():
-        ticker = row["Ticker"].upper() if pd.notna(row["Ticker"]) else None
-        if not ticker:
-            continue
-        shares = float(row["Shares"]) if pd.notna(row["Shares"]) else 0
-        cost = float(row["Avg Cost"]) if pd.notna(row["Avg Cost"]) else 0
-        price = float(row["Current Price"]) if pd.notna(row["Current Price"]) else 0
-        value = shares * price
-        portfolio[ticker] = {'shares': shares, 'cost': cost, 'price': price, 'value': value}
-    if not portfolio:
-        return "", None
-    total_value_now = sum(v['value'] for v in portfolio.values())
-    allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
-    fig_alloc, ax_alloc = plt.subplots()
-    ax_alloc.pie(allocations.values(), labels=allocations.keys(), autopct='%1.1f%%')
-    ax_alloc.set_title('Portfolio Allocation')
-    buf_alloc = io.BytesIO()
-    fig_alloc.savefig(buf_alloc, format='png')
-    buf_alloc.seek(0)
-    chart_alloc = Image.open(buf_alloc)
-    plt.close(fig_alloc)  # Close the figure to free memory
-
-    def project_value(value, years, rate):
-        return value * (1 + rate / 100) ** years
+            riskier = ticker2
+            less_risky = ticker1
+            higher_vol = vol2
+            lower_vol = vol1
+            riskier_sharpe = sharpe2
+            less_sharpe = sharpe1
+        yield f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
+        return
 
-
-
-
-
+    # For other queries, fall back to LLM generation
+    conversation = [{"role": "system", "content": system_prompt}]
+    for user, assistant in history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
 
-
-
-
-
-
+    # Generate response using LLM (streamed)
+    response = llm.create_chat_completion(
+        messages=conversation,
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        stream=True
    )
-    return data_str, chart_alloc
 
-
-
-
-
-
-    df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
-    for i in df.index:
-        ticker = df.at[i, "Ticker"]
-        if pd.notna(ticker) and ticker.strip():
-            try:
-                price = yf.Ticker(ticker.upper()).info.get('currentPrice', None)
-                if price is not None:
-                    df.at[i, "Current Price"] = price
-            except Exception as e:
-                logger.warning(f"Failed to fetch price for {ticker}: {str(e)}")
-    return df
+    partial_text = ""
+    for chunk in response:
+        if "content" in chunk["choices"][0]["delta"]:
+            partial_text += chunk["choices"][0]["delta"]["content"]
+            yield partial_text
 
-# Gradio interface setup
-with gr.Blocks(theme=themes.
+# Gradio interface setup (assuming this is part of the original code)
+with gr.Blocks(theme=themes.Default()) as demo:
     gr.Markdown(DESCRIPTION)
-    chatbot = gr.Chatbot(label="FinChat", type="messages")
-    msg = gr.Textbox(label="Ask a finance question", placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'", info="Enter your query here. Portfolio data will be appended if provided.")
-    with gr.Row():
-        submit = gr.Button("Submit", variant="primary")
-        clear = gr.Button("Clear")
-    gr.Examples(
-        examples=["What is CAGR?", "Average return for AAPL between 2010 and 2020", "Hi", "Explain compound interest"],
-        inputs=msg,
-        label="Example Queries"
-    )
-    with gr.Accordion("Enter Portfolio for Projections", open=False):
-        portfolio_df = gr.Dataframe(
-            headers=["Ticker", "Shares", "Avg Cost", "Current Price"],
-            datatype=["str", "number", "number", "number"],
-            row_count=3,
-            col_count=(4, "fixed"),
-            label="Portfolio Data",
-            interactive=True
-        )
-        gr.Markdown("Enter your stocks here. You can add more rows by editing the table.")
-        fetch_button = gr.Button("Fetch Current Prices", variant="secondary")
-        fetch_button.click(fetch_current_prices, inputs=portfolio_df, outputs=portfolio_df)
-        growth_rate = gr.Slider(minimum=5, maximum=50, step=5, value=10, label="Annual Growth Rate (%)", interactive=True, info="Select the assumed annual growth rate for projections.")
-        growth_rate_label = gr.Markdown("**Selected Growth Rate: 10%**")
-    with gr.Accordion("Advanced Settings", open=False):
-        system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6, info="Customize the AI's system prompt.")
-        temperature = gr.Slider(label="Temperature", value=0.6, minimum=0.0, maximum=1.0, step=0.05, info="Controls randomness: lower is more deterministic.")
-        top_p = gr.Slider(label="Top P", value=0.9, minimum=0.0, maximum=1.0, step=0.05, info="Nucleus sampling: higher includes more diverse tokens.")
-        top_k = gr.Slider(label="Top K", value=50, minimum=1, maximum=100, step=1, info="Top-K sampling: limits to top K tokens.")
-        repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, info="Penalizes repeated tokens.")
-        max_new_tokens = gr.Slider(label="Max New Tokens", value=DEFAULT_MAX_NEW_TOKENS, minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, info="Maximum length of generated response.")
     gr.Markdown(LICENSE)
 
-
-
-
-
-
-        return "", history
-    return "", history + [{"role": "user", "content": message}]
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Enter your question")
+    with gr.Row():
+        submit = gr.Button("Submit")
+        clear = gr.Button("Clear")
 
-
-
-
-
-
-
-
-    history[-1]["content"] = message
-    history.append({"role": "assistant", "content": ""})
-    for new_text in generate(message, history[:-1], sys_prompt, mnt, temp, tp, tk, rp):
-        history[-1]["content"] = new_text
-        yield history, f"**Selected Growth Rate: {growth_rate}%**"
-    if chart_alloc:
-        history.append({"role": "assistant", "content": "", "image": chart_alloc})
-        yield history, f"**Selected Growth Rate: {growth_rate}%**"
-
-
-
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
-    )
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
+    advanced = gr.Accordion("Advanced Settings", open=False)
+    with advanced:
+        system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
+        max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
+        temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
+        top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")
+
+    submit.click(generate, [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k], chatbot, queue=False).then(
+        lambda: "", None, msg
    )
-    clear.click(lambda:
+    clear.click(lambda: None, None, chatbot)
 
-demo.
+demo.launch()
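Review notes on the new code follow; the sketches are illustrative and not part of the commit.

The `chat_format="chatml"` argument matters because llama.cpp applies the chat template itself before tokenizing. A minimal sketch of driving a model loaded this way through the chat API (the GGUF filename here is a placeholder; the Space resolves its own `model_path` via `hf_hub_download`):

```python
from llama_cpp import Llama

# Placeholder model path, for illustration only.
llm = Llama(model_path="phi-2.Q4_K_M.gguf", n_ctx=1024, chat_format="chatml")

out = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are FinChat."},
        {"role": "user", "content": "What is CAGR?"},
    ],
    max_tokens=64,
)
print(out["choices"][0]["message"]["content"])
```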
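The CAGR arithmetic in `calculate_cagr` can be sanity-checked by hand: a price that doubles over ten years should come out near 7.2% per year.

```python
# CAGR = (end_price / start_price) ** (1 / num_years) - 1
start_price, end_price, num_years = 100.0, 200.0, 10.0
cagr = (end_price / start_price) ** (1 / num_years) - 1
print(f"{cagr * 100:.2f}%")  # 7.18% -- a double in 10 years
```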
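`calculate_risk_metrics` annualizes daily returns with the conventional sqrt(252) scaling, and because `volatility` is stored in percent, dividing it back by 100 makes the Sharpe ratio the usual (annualized mean − risk-free) / annualized stdev. A self-contained sketch on synthetic daily returns (the 2% risk-free rate is the helper's own hard-coded assumption):

```python
import numpy as np

rng = np.random.default_rng(42)
daily = rng.normal(0.0004, 0.01, 252 * 5)  # synthetic daily returns, ~5 years

ann_vol = daily.std() * np.sqrt(252)       # annualized volatility (decimal)
ann_ret = daily.mean() * 252               # annualized mean return (decimal)
sharpe = (ann_ret - 0.02) / ann_vol        # same formula as the helper
print(f"vol={ann_vol:.2%}  sharpe={sharpe:.2f}")
```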
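One caveat for both helpers: they index `data['Adj Close']`, but recent yfinance releases default to `auto_adjust=True`, which returns adjusted prices under `'Close'` and omits `'Adj Close'` altogether; the removed code passed `auto_adjust=False` explicitly and also flattened MultiIndex columns. A guard along these lines (a sketch, making no assumption about the installed version) keeps the helpers working either way:

```python
data = yf.download(ticker, start=start_date, end=end_date,
                   auto_adjust=False, progress=False)
if isinstance(data.columns, pd.MultiIndex):
    data.columns = data.columns.droplevel(1)  # newer yfinance nests by ticker
price_col = 'Adj Close' if 'Adj Close' in data.columns else 'Close'
start_price = data[price_col].iloc[0]
end_price = data[price_col].iloc[-1]
```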
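The new CAGR detector is narrower than the multi-ticker regex it replaces: `\w+` captures exactly one ticker, and the phrase must read "average return for ... between YYYY and YYYY". A quick check of what does and does not match:

```python
import re

pattern = r'average return for (\w+) between (\d{4}) and (\d{4})'

m = re.search(pattern, "average return for AAPL between 2010 and 2020".lower())
assert m and m.group(1).upper() == "AAPL"

# Ticker lists no longer match; the removed version split on commas and "and".
assert re.search(pattern, "average return for aapl and msft between 2010 and 2020") is None
```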
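In llama-cpp-python's streamed chat completions, the first chunk's delta typically carries only the role and the last an empty delta plus a finish_reason, which is what the `"content" in delta` test skips over. A slightly more defensive variant of the new loop (same behavior, just tolerant of a chunk without a delta key):

```python
partial_text = ""
for chunk in response:  # response = llm.create_chat_completion(..., stream=True)
    delta = chunk["choices"][0].get("delta", {})
    content = delta.get("content")
    if content:
        partial_text += content
        yield partial_text
```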
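Finally, a wiring detail worth testing: `generate` yields plain strings, while the default `gr.Chatbot` renders a list of (user, assistant) pairs, and `submit.click` streams `generate`'s output straight into the chatbot. If that shape mismatch shows up, a thin adapter along these lines (name and signature hypothetical) would bridge the two:

```python
def respond(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k):
    # Re-shape generate()'s string stream into the (user, assistant) pairs
    # that gr.Chatbot renders.
    history = history or []
    for partial in generate(message, history, system_prompt,
                            max_new_tokens, temperature, top_p, top_k):
        yield history + [(message, partial)]

# submit.click(respond, [msg, chatbot, system_prompt, max_new_tokens,
#                        temperature, top_p, top_k], chatbot)
```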