AnilNiraula committed
Commit f395a8f · verified · 1 Parent(s): 5bb1ea5

Update app.py

Files changed (1):
  1. app.py +131 -304
app.py CHANGED
@@ -11,6 +11,9 @@ from huggingface_hub import hf_hub_download, login
 import logging
 import pandas as pd
 import torch
+import yfinance as yf
+from datetime import datetime, timedelta
+from math import sqrt

 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -27,14 +30,14 @@ except ModuleNotFoundError:
     else:
         logger.info("Installing llama-cpp-python without additional flags.")
         subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
-        from llama_cpp import Llama
+        from llama_cpp import Llama

 # Install yfinance if not present (for CAGR calculations)
 try:
     import yfinance as yf
 except ModuleNotFoundError:
     subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
-    import yfinance as yf
+    import yfinance as yf

 # Import pandas for handling DataFrame column structures
 import pandas as pd
@@ -46,9 +49,9 @@ try:
     import io
 except ModuleNotFoundError:
     subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
-    import matplotlib.pyplot as plt
-    from PIL import Image
-    import io
+    import matplotlib.pyplot as plt
+    from PIL import Image
+    import io

 MAX_MAX_NEW_TOKENS = 512
 DEFAULT_MAX_NEW_TOKENS = 512
@@ -72,12 +75,12 @@ try:
     llm = Llama(
         model_path=model_path,
         n_ctx=1024,
-        n_batch=1024,  # Increased for faster processing
+        n_batch=1024,  # Increased for faster processing
         n_threads=multiprocessing.cpu_count(),
         n_gpu_layers=n_gpu_layers,
-        chat_format="chatml"  # Phi-2 uses ChatML format in llama.cpp
+        chat_format="chatml"  # Phi-2 uses ChatML format in llama.cpp
     )
-    logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
+    logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
     # Warm up the model for faster initial inference
     llm("Warm-up prompt", max_tokens=1, echo=False)
     logger.info("Model warm-up completed.")
@@ -90,323 +93,147 @@ atexit.register(llm.close)

 DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
 Always respond exclusively in English. Use bullet points for clarity.
+Do not substitute or alter stock symbols provided in the user's query. Always use the exact tickers mentioned.
 Example:
-User: average return for TSLA between 2010 and 2020
-Assistant:
-- TSLA CAGR (2010-2020): ~63.01%
+User: average return for AAPL between 2010 and 2020
+Assistant:
+- AAPL CAGR (2010-2020): ~27.24%
 - Represents average annual return with compounding
-- Past performance not indicative of future results
-- Consult a financial advisor"""
+- Past performance is not indicative of future results."""

-# Company name to ticker mapping (expand as needed)
-COMPANY_TO_TICKER = {
-    "opendoor": "OPEN",
-    "tesla": "TSLA",
-    "apple": "AAPL",
-    "amazon": "AMZN",
-    # Add more mappings for common companies
-}
+# Function to calculate CAGR using yfinance
+def calculate_cagr(ticker, start_date, end_date):
+    try:
+        data = yf.download(ticker, start=start_date, end=end_date)
+        if data.empty:
+            return None
+        start_price = data['Adj Close'].iloc[0]
+        end_price = data['Adj Close'].iloc[-1]
+        num_years = (data.index[-1] - data.index[0]).days / 365.25
+        cagr = (end_price / start_price) ** (1 / num_years) - 1
+        return cagr * 100  # Return as percentage
+    except Exception as e:
+        logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
+        return None
+
+# New function to calculate risk metrics using yfinance
+def calculate_risk_metrics(ticker, years=5):
+    try:
+        end_date = datetime.now().strftime('%Y-%m-%d')
+        start_date = (datetime.now() - timedelta(days=365 * years)).strftime('%Y-%m-%d')
+        data = yf.download(ticker, start=start_date, end=end_date)
+        if data.empty:
+            return None, None
+        returns = data['Adj Close'].pct_change().dropna()
+        volatility = returns.std() * sqrt(252) * 100  # Annualized volatility in percent
+        mean_return = returns.mean() * 252  # Annualized mean return
+        risk_free_rate = 0.02  # Assumed risk-free rate (e.g., 2%)
+        sharpe = (mean_return - risk_free_rate) / (volatility / 100)  # Sharpe ratio
+        return volatility, sharpe
+    except Exception as e:
+        logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
+        return None, None

+# Assuming the generate function handles the chat logic (extended to include risk comparison)
 def generate(
     message: str,
-    chat_history: list[dict],
-    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
 ) -> Iterator[str]:
-    logger.info(f"Generating response for message: {message}")
-    lower_message = message.lower().strip()
-
-    if lower_message in ["hi", "hello"]:
-        response = "I'm FinChat, your financial advisor. Ask me anything finance-related!"
-        logger.info("Quick response for 'hi'/'hello' generated.")
-        yield response
-        return
-
-    if "what is cagr" in lower_message:
-        response = """- CAGR stands for Compound Annual Growth Rate.
-- It measures the mean annual growth rate of an investment over a specified period longer than one year, accounting for compounding.
-- Formula: CAGR = (Ending Value / Beginning Value)^(1 / Number of Years) - 1
-- Useful for comparing investments over time.
-- Past performance not indicative of future results. Consult a financial advisor."""
-        logger.info("Quick response for 'what is cagr' generated.")
-        yield response
-        return
-
-    # Check for CAGR/average return queries (use re.search for flexible matching)
-    match = re.search(r'(?:average return|cagr) for ([\w\s,]+(?:and [\w\s,]+)?) between (\d{4}) and (\d{4})', lower_message)
-    if match:
-        tickers_str, start_year, end_year = match.groups()
-        tickers = [t.strip().upper() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
-
-        # Apply company-to-ticker mapping
-        for i in range(len(tickers)):
-            lower_ticker = tickers[i].lower()
-            if lower_ticker in COMPANY_TO_TICKER:
-                tickers[i] = COMPANY_TO_TICKER[lower_ticker]
-
-        responses = []
-        if int(end_year) <= int(start_year):
-            yield "The specified time period is invalid (end year must be after start year)."
+    if not system_prompt:
+        system_prompt = DEFAULT_SYSTEM_PROMPT
+
+    # Detect CAGR query
+    cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
+    if cagr_match:
+        ticker = cagr_match.group(1).upper()
+        start_year = cagr_match.group(2)
+        end_year = cagr_match.group(3)
+        start_date = f"{start_year}-01-01"
+        end_date = f"{end_year}-12-31"
+        cagr = calculate_cagr(ticker, start_date, end_date)
+        if cagr is not None:
+            yield f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
             return
-
-        for ticker in tickers:
-            try:
-                # Download data with adjusted close prices
-                data = yf.download(ticker, start=f"{start_year}-01-01", end=f"{end_year}-12-31", progress=False, auto_adjust=False)
-                # Handle potential MultiIndex columns in newer yfinance versions
-                if isinstance(data.columns, pd.MultiIndex):
-                    data.columns = data.columns.droplevel(1)
-                if not data.empty:
-                    # Check if 'Adj Close' column exists
-                    if 'Adj Close' not in data.columns:
-                        responses.append(f"- {ticker}: Error - Adjusted Close price data not available.")
-                        logger.error(f"No 'Adj Close' column for {ticker}.")
-                        continue
-                    # Ensure data is not MultiIndex for single ticker (already handled)
-                    initial = data['Adj Close'].iloc[0]
-                    final = data['Adj Close'].iloc[-1]
-                    start_date = data.index[0]
-                    end_date = data.index[-1]
-                    days = (end_date - start_date).days
-                    years = days / 365.25
-                    if years > 0 and pd.notna(initial) and pd.notna(final):
-                        cagr = ((final / initial) ** (1 / years) - 1) * 100
-                        responses.append(f"- {ticker}: ~{cagr:.2f}%")
-                    else:
-                        responses.append(f"- {ticker}: Invalid period or missing price data.")
-                else:
-                    responses.append(f"- {ticker}: No historical data available between {start_year} and {end_year}.")
-            except Exception as e:
-                logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
-                responses.append(f"- {ticker}: Error calculating CAGR - {str(e)}")
-
-        full_response = f"CAGR for the requested stocks from {start_year} to {end_year}:\n" + "\n".join(responses) + "\n- Represents average annual returns with compounding\n- Past performance not indicative of future results\n- Consult a financial advisor"
-        full_response = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', full_response).strip()  # Clean any trailing tokens
-
-        # Estimate token count to ensure response fits within max_new_tokens
-        response_tokens = len(llm.tokenize(full_response.encode("utf-8"), add_bos=False))
-        if response_tokens > max_new_tokens:
-            logger.warning(f"CAGR response tokens ({response_tokens}) exceed max_new_tokens ({max_new_tokens}). Truncating to first complete sentence.")
-            sentence_endings = ['.', '!', '?']
-            first_sentence_end = min([full_response.find(ending) + 1 for ending in sentence_endings if full_response.find(ending) != -1], default=len(full_response))
-            full_response = full_response[:first_sentence_end] if first_sentence_end > 0 else "Response truncated due to length; please increase Max New Tokens."
-
-        logger.info("CAGR response generated.")
-        yield full_response
-        return
-
-    # Build conversation messages (limit history to last 3 for speed)
-    conversation = [{"role": "system", "content": system_prompt}]
-    for msg in chat_history[-3:]:  # Reduced from 5 to 3 for faster processing
-        if msg["role"] == "user":
-            conversation.append({"role": "user", "content": msg["content"]})
-        elif msg["role"] == "assistant":
-            conversation.append({"role": "assistant", "content": msg["content"]})
-    conversation.append({"role": "user", "content": message})
-
-    # Approximate token length check and truncate if necessary
-    prompt_text = "\n".join(d["content"] for d in conversation)
-    input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
-    while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
-        logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit ({MAX_INPUT_TOKEN_LENGTH}). Truncating history.")
-        if len(conversation) > 2:  # Preserve system prompt and current user message
-            conversation.pop(1)  # Remove oldest user/assistant pair
-            prompt_text = "\n".join(d["content"] for d in conversation)
-            input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
         else:
-            yield "Error: Input is too long even after truncation. Please shorten your query."
+            yield "Unable to calculate CAGR for the specified period."
             return

-    # Generate response with sentence boundary checking and token cleanup
-    try:
-        response = ""
-        sentence_buffer = ""
-        token_count = 0
-        stream = llm.create_chat_completion(
-            messages=conversation,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repeat_penalty=repetition_penalty,
-            stream=True
-        )
-        sentence_endings = ['.', '!', '?']
-        for chunk in stream:
-            delta = chunk["choices"][0]["delta"]
-            if "content" in delta and delta["content"] is not None:
-                # Clean the chunk by removing ChatML tokens or similar
-                cleaned_chunk = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', delta["content"])
-                if not cleaned_chunk:
-                    continue
-                sentence_buffer += cleaned_chunk
-                response += cleaned_chunk
-                # Approximate token count for the chunk
-                chunk_tokens = len(llm.tokenize(cleaned_chunk.encode("utf-8"), add_bos=False))
-                token_count += chunk_tokens
-                # Check for sentence boundary
-                if any(sentence_buffer.strip().endswith(ending) for ending in sentence_endings):
-                    yield response
-                    sentence_buffer = ""  # Clear buffer after yielding a complete sentence
-                # Removed early truncation to allow full token utilization
-
-            if chunk["choices"][0]["finish_reason"] is not None:
-                # Yield any remaining complete sentence in the buffer
-                if sentence_buffer.strip():
-                    last_sentence_end = max([sentence_buffer.rfind(ending) for ending in sentence_endings if sentence_buffer.rfind(ending) != -1], default=-1)
-                    if last_sentence_end != -1:
-                        response = response[:response.rfind(sentence_buffer) + last_sentence_end + 1]
-                        yield response
-                    else:
-                        yield response
-                else:
-                    yield response
-                break
-        logger.info("Response generation completed.")
-    except ValueError as ve:
-        if "exceed context window" in str(ve):
-            yield "Error: Prompt too long for context window. Please try a shorter query or clear history."
+    # Detect risk comparison query
+    risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
+    if risk_match:
+        ticker1 = risk_match.group(1).upper()
+        ticker2 = risk_match.group(2).upper()
+        vol1, sharpe1 = calculate_risk_metrics(ticker1)
+        vol2, sharpe2 = calculate_risk_metrics(ticker2)
+        if vol1 is None or vol2 is None:
+            yield "Unable to fetch risk metrics for one or both tickers."
+            return
+        if vol1 > vol2:
+            riskier = ticker1
+            less_risky = ticker2
+            higher_vol = vol1
+            lower_vol = vol2
+            riskier_sharpe = sharpe1
+            less_sharpe = sharpe2
         else:
-            logger.error(f"Error during response generation: {str(ve)}")
-            yield f"Error generating response: {str(ve)}"
-    except Exception as e:
-        logger.error(f"Error during response generation: {str(e)}")
-        yield f"Error generating response: {str(e)}"
-
-def process_portfolio(df, growth_rate):
-    if df is None or len(df) == 0:
-        return "", None
-    # Convert to DataFrame if needed
-    if not isinstance(df, pd.DataFrame):
-        df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
-    df = df.dropna(subset=["Ticker"])
-    portfolio = {}
-    for _, row in df.iterrows():
-        ticker = row["Ticker"].upper() if pd.notna(row["Ticker"]) else None
-        if not ticker:
-            continue
-        shares = float(row["Shares"]) if pd.notna(row["Shares"]) else 0
-        cost = float(row["Avg Cost"]) if pd.notna(row["Avg Cost"]) else 0
-        price = float(row["Current Price"]) if pd.notna(row["Current Price"]) else 0
-        value = shares * price
-        portfolio[ticker] = {'shares': shares, 'cost': cost, 'price': price, 'value': value}
-    if not portfolio:
-        return "", None
-    total_value_now = sum(v['value'] for v in portfolio.values())
-    allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
-    fig_alloc, ax_alloc = plt.subplots()
-    ax_alloc.pie(allocations.values(), labels=allocations.keys(), autopct='%1.1f%%')
-    ax_alloc.set_title('Portfolio Allocation')
-    buf_alloc = io.BytesIO()
-    fig_alloc.savefig(buf_alloc, format='png')
-    buf_alloc.seek(0)
-    chart_alloc = Image.open(buf_alloc)
-    plt.close(fig_alloc)  # Close the figure to free memory
-
-    def project_value(value, years, rate):
-        return value * (1 + rate / 100) ** years
+            riskier = ticker2
+            less_risky = ticker1
+            higher_vol = vol2
+            lower_vol = vol1
+            riskier_sharpe = sharpe2
+            less_sharpe = sharpe1
+        yield f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
+        return

-    total_value_1yr = sum(project_value(v['value'], 1, growth_rate) for v in portfolio.values())
-    total_value_2yr = sum(project_value(v['value'], 2, growth_rate) for v in portfolio.values())
-    total_value_5yr = sum(project_value(v['value'], 5, growth_rate) for v in portfolio.values())
-    total_value_10yr = sum(project_value(v['value'], 10, growth_rate) for v in portfolio.values())
+    # For other queries, fall back to LLM generation
+    conversation = [{"role": "system", "content": system_prompt}]
+    for user, assistant in history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})

-    data_str = (
-        "User portfolio:\n" +
-        "\n".join(f"- {k}: {v['shares']} shares, avg cost {v['cost']}, current price {v['price']}, value ${v['value']:,.2f}" for k, v in portfolio.items()) +
-        f"\nTotal value now: ${total_value_now:,.2f}\nProjected (at {growth_rate}% annual growth):\n" +
-        f"- 1 year: ${total_value_1yr:,.2f}\n- 2 years: ${total_value_2yr:,.2f}\n- 5 years: ${total_value_5yr:,.2f}\n- 10 years: ${total_value_10yr:,.2f}"
+    # Generate response using LLM (streamed)
+    response = llm.create_chat_completion(
+        messages=conversation,
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        stream=True
     )
-    return data_str, chart_alloc

-def fetch_current_prices(df):
-    if df is None or len(df) == 0:
-        return df
-    # Convert to DataFrame if needed
-    if not isinstance(df, pd.DataFrame):
-        df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
-    for i in df.index:
-        ticker = df.at[i, "Ticker"]
-        if pd.notna(ticker) and ticker.strip():
-            try:
-                price = yf.Ticker(ticker.upper()).info.get('currentPrice', None)
-                if price is not None:
-                    df.at[i, "Current Price"] = price
-            except Exception as e:
-                logger.warning(f"Failed to fetch price for {ticker}: {str(e)}")
-    return df
+    partial_text = ""
+    for chunk in response:
+        if "content" in chunk["choices"][0]["delta"]:
+            partial_text += chunk["choices"][0]["delta"]["content"]
+            yield partial_text

-# Gradio interface setup
-with gr.Blocks(theme=themes.Soft(), css="""#chatbot {height: 800px; overflow: auto;}""") as demo:
+# Gradio interface setup (assuming this is part of the original code)
+with gr.Blocks(theme=themes.Default()) as demo:
     gr.Markdown(DESCRIPTION)
-    chatbot = gr.Chatbot(label="FinChat", type="messages")
-    msg = gr.Textbox(label="Ask a finance question", placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'", info="Enter your query here. Portfolio data will be appended if provided.")
-    with gr.Row():
-        submit = gr.Button("Submit", variant="primary")
-        clear = gr.Button("Clear")
-    gr.Examples(
-        examples=["What is CAGR?", "Average return for AAPL between 2010 and 2020", "Hi", "Explain compound interest"],
-        inputs=msg,
-        label="Example Queries"
-    )
-    with gr.Accordion("Enter Portfolio for Projections", open=False):
-        portfolio_df = gr.Dataframe(
-            headers=["Ticker", "Shares", "Avg Cost", "Current Price"],
-            datatype=["str", "number", "number", "number"],
-            row_count=3,
-            col_count=(4, "fixed"),
-            label="Portfolio Data",
-            interactive=True
-        )
-        gr.Markdown("Enter your stocks here. You can add more rows by editing the table.")
-        fetch_button = gr.Button("Fetch Current Prices", variant="secondary")
-        fetch_button.click(fetch_current_prices, inputs=portfolio_df, outputs=portfolio_df)
-        growth_rate = gr.Slider(minimum=5, maximum=50, step=5, value=10, label="Annual Growth Rate (%)", interactive=True, info="Select the assumed annual growth rate for projections.")
-        growth_rate_label = gr.Markdown("**Selected Growth Rate: 10%**")
-    with gr.Accordion("Advanced Settings", open=False):
-        system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6, info="Customize the AI's system prompt.")
-        temperature = gr.Slider(label="Temperature", value=0.6, minimum=0.0, maximum=1.0, step=0.05, info="Controls randomness: lower is more deterministic.")
-        top_p = gr.Slider(label="Top P", value=0.9, minimum=0.0, maximum=1.0, step=0.05, info="Nucleus sampling: higher includes more diverse tokens.")
-        top_k = gr.Slider(label="Top K", value=50, minimum=1, maximum=100, step=1, info="Top-K sampling: limits to top K tokens.")
-        repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, info="Penalizes repeated tokens.")
-        max_new_tokens = gr.Slider(label="Max New Tokens", value=DEFAULT_MAX_NEW_TOKENS, minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, info="Maximum length of generated response.")
     gr.Markdown(LICENSE)

-    def update_growth_rate_label(growth_rate):
-        return f"**Selected Growth Rate: {growth_rate}%**"
-
-    def user(message, history):
-        if not message:
-            return "", history
-        return "", history + [{"role": "user", "content": message}]
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Enter your question")
+    with gr.Row():
+        submit = gr.Button("Submit")
+        clear = gr.Button("Clear")

-    def bot(history, sys_prompt, temp, tp, tk, rp, mnt, portfolio_df, growth_rate):
-        if not history:
-            logger.warning("History is empty, initializing with user message.")
-            history = [{"role": "user", "content": ""}]
-        message = history[-1]["content"]
-        portfolio_data, chart_alloc = process_portfolio(portfolio_df, growth_rate)
-        message += "\n" + portfolio_data
-        history[-1]["content"] = message
-        history.append({"role": "assistant", "content": ""})
-        for new_text in generate(message, history[:-1], sys_prompt, mnt, temp, tp, tk, rp):
-            history[-1]["content"] = new_text
-            yield history, f"**Selected Growth Rate: {growth_rate}%**"
-        if chart_alloc:
-            history.append({"role": "assistant", "content": "", "image": chart_alloc})
-            yield history, f"**Selected Growth Rate: {growth_rate}%**"
+    advanced = gr.Accordion("Advanced Settings", open=False)
+    with advanced:
+        system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
+        max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
+        temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
+        top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")

-    growth_rate.change(update_growth_rate_label, inputs=growth_rate, outputs=growth_rate_label)
-    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
-    )
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
+    submit.click(generate, [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k], chatbot, queue=False).then(
+        lambda: "", None, msg
     )
-    clear.click(lambda: [], None, chatbot, queue=False)
+    clear.click(lambda: None, None, chatbot)

-demo.queue(max_size=128).launch()
+demo.launch()
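
As a sanity check on the arithmetic the new calculate_cagr helper implements, here is a minimal offline sketch; the prices below are made-up illustration values, not yfinance output:

    # Hypothetical prices standing in for the first/last 'Adj Close' values.
    start_price = 100.0
    end_price = 250.0
    num_years = 5.0  # in the helper: (data.index[-1] - data.index[0]).days / 365.25

    # CAGR = (ending value / beginning value) ** (1 / number of years) - 1
    cagr = (end_price / start_price) ** (1 / num_years) - 1
    print(f"~{cagr * 100:.2f}%")  # ~20.11%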
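The annualization in calculate_risk_metrics scales the daily-return standard deviation by sqrt(252) trading days, the daily mean by 252, and assumes a 2% risk-free rate for the Sharpe ratio. A sketch with synthetic daily returns in place of downloaded data:

    import numpy as np
    from math import sqrt

    # Synthetic daily returns standing in for data['Adj Close'].pct_change().dropna().
    rng = np.random.default_rng(0)
    returns = rng.normal(loc=0.0005, scale=0.02, size=252)

    volatility = returns.std() * sqrt(252) * 100  # annualized volatility, in percent
    mean_return = returns.mean() * 252            # annualized mean return, as a fraction
    risk_free_rate = 0.02                         # same 2% assumption as the commit
    sharpe = (mean_return - risk_free_rate) / (volatility / 100)
    print(f"volatility ~{volatility:.1f}%, Sharpe ~{sharpe:.2f}")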
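Note that the rewritten generate routes only queries matching the two regexes exactly; the removed version also handled comma/'and'-separated ticker lists and company-name mapping. A quick check of what the new patterns capture:

    import re

    print(re.search(r'average return for (\w+) between (\d{4}) and (\d{4})',
                    "average return for AAPL between 2010 and 2020".lower()).groups())
    # ('aapl', '2010', '2020'); the handler then upper-cases the ticker

    print(re.search(r'which stock is riskier (\w+) or (\w+)',
                    "which stock is riskier TSLA or AAPL".lower()).groups())
    # ('tsla', 'aapl')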