AnilNiraula commited on
Commit
c3e2fa0
·
verified ·
1 Parent(s): dcc2503

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -37
app.py CHANGED
@@ -5,45 +5,25 @@ from datetime import datetime, timedelta
5
  import difflib
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import torch
8
- import yfinance as yf # Primary source for accurate adjusted data
9
  from functools import lru_cache
10
  import pandas as pd
11
- import os
12
- from pycharts import CompanyClient # Retained, but with fallback
13
 
14
  # Define the list of tickers
15
  tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
16
 
17
- # YCharts API setup (requires a valid key; obtain from ycharts.com)
18
- ycharts_api_key = os.environ.get('YCHARTS_API_KEY', 'your-api-key-here') # Placeholder; set in environment
19
- company_client = CompanyClient(ycharts_api_key)
20
-
21
- # Prefetch stock data for all tickers at startup, preferring yfinance for adjusted prices
22
  all_data = {}
23
  try:
24
  now = datetime.now().strftime('%Y-%m-%d')
25
  for ticker in tickers:
26
- # Attempt YCharts fetch (assuming 'price' is unadjusted; fallback to yfinance)
27
- try:
28
- past = datetime(2020, 1, 1)
29
- series_rsp = company_client.get_series([ticker], ['price'], query_start_date=past, query_end_date=now)
30
- if ticker in series_rsp and 'price' in series_rsp[ticker]:
31
- df = pd.DataFrame({
32
- 'Date': series_rsp[ticker]['dates'],
33
- 'Close': series_rsp[ticker]['price']['values'] # Likely unadjusted
34
- }).set_index('Date')
35
- all_data[ticker] = df
36
- raise Exception("Fallback to yfinance") # Force fallback for accuracy in this update
37
- except:
38
- # Use yfinance for adjusted data
39
- all_data[ticker] = yf.download(ticker, start='2020-01-01', end=now)
40
  except Exception as e:
41
- print(f"Error prefetching data: {e}. Using yfinance exclusively.")
42
- all_data = {ticker: yf.download(ticker, start='2020-01-01', end=now.strftime('%Y-%m-%d'))
43
- for ticker in tickers}
44
 
45
  # Create a DataFrame with 'Adj Close' columns for each ticker
46
- adj_close_data = pd.DataFrame({ticker: data['Adj Close'] for ticker, data in all_data.items() if 'Adj Close' in data.columns})
47
 
48
  # Display the first few rows to verify (for debugging; remove in production)
49
  print(adj_close_data.head())
@@ -53,21 +33,28 @@ available_symbols = ['TSLA', 'MSFT', 'NVDA', 'GOOG', 'AMZN', 'SPY', 'AAPL', 'MET
53
 
54
  @lru_cache(maxsize=100)
55
  def fetch_stock_data(symbol, start_date, end_date):
56
- if symbol in all_data:
57
  # Use preloaded data and slice by date
58
  hist = all_data[symbol]
59
  return hist[(hist.index >= start_date) & (hist.index <= end_date)]
60
  else:
61
- # Fallback to on-demand fetch with yfinance for adjusted data (skip YCharts for CAGR accuracy)
62
  try:
63
  ticker = yf.Ticker(symbol)
64
- hist = ticker.history(start=start_date, end=end_date)
65
  return hist
66
  except Exception as e:
67
  print(f"Error fetching data for {symbol}: {e}")
68
  return None
69
 
70
  def parse_period(query):
 
 
 
 
 
 
 
71
  match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
72
  if match:
73
  num = int(match.group(1))
@@ -89,10 +76,10 @@ def find_closest_symbol(input_symbol):
89
 
90
  def calculate_growth_rate(start_date, end_date, symbol):
91
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
92
- if hist is None or hist.empty or 'Adj Close' not in hist.columns:
93
  return None
94
- beginning_value = hist.iloc[0]['Adj Close']
95
- ending_value = hist.iloc[-1]['Adj Close']
96
  years = (end_date - start_date).days / 365.25
97
  if years <= 0:
98
  return 0
@@ -137,10 +124,16 @@ def generate_response(user_query, enable_thinking=False):
137
  "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
138
  "For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
139
  "Provide accurate, concise advice based on data."
140
- # Assume continuation of original prompt here if additional text exists
141
  )
142
- # Original generation logic (tokenize, generate, decode) would follow here; integrate as per the full original code
 
143
 
144
- # Gradio interface setup (assumed from original; add if needed)
145
- # demo = gr.Interface(...)
146
- # demo.launch()
 
 
 
 
 
 
 
5
  import difflib
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import torch
8
+ import yfinance as yf
9
  from functools import lru_cache
10
  import pandas as pd
 
 
11
 
12
  # Define the list of tickers
13
  tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
14
 
15
+ # Prefetch stock data for all tickers at startup using yfinance
 
 
 
 
16
  all_data = {}
17
  try:
18
  now = datetime.now().strftime('%Y-%m-%d')
19
  for ticker in tickers:
20
+ all_data[ticker] = yf.download(ticker, start='2020-01-01', end=now, auto_adjust=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  except Exception as e:
22
+ print(f"Error prefetching data: {e}")
23
+ all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
 
24
 
25
  # Create a DataFrame with 'Adj Close' columns for each ticker
26
+ adj_close_data = pd.DataFrame({ticker: data['Close'] for ticker, data in all_data.items() if not data.empty})
27
 
28
  # Display the first few rows to verify (for debugging; remove in production)
29
  print(adj_close_data.head())
 
33
 
34
  @lru_cache(maxsize=100)
35
  def fetch_stock_data(symbol, start_date, end_date):
36
+ if symbol in all_data and not all_data[symbol].empty:
37
  # Use preloaded data and slice by date
38
  hist = all_data[symbol]
39
  return hist[(hist.index >= start_date) & (hist.index <= end_date)]
40
  else:
41
+ # Fetch on-demand with yfinance
42
  try:
43
  ticker = yf.Ticker(symbol)
44
+ hist = ticker.history(start=start_date, end=end_date, auto_adjust=True)
45
  return hist
46
  except Exception as e:
47
  print(f"Error fetching data for {symbol}: {e}")
48
  return None
49
 
50
  def parse_period(query):
51
+ # Enhanced to handle year ranges like "between 2010 and 2020"
52
+ range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', query.lower())
53
+ if range_match:
54
+ start_year = int(range_match.group(1))
55
+ end_year = int(range_match.group(2))
56
+ return (datetime(end_year, 12, 31) - datetime(start_year, 1, 1))
57
+ # Fallback to original period parsing
58
  match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
59
  if match:
60
  num = int(match.group(1))
 
76
 
77
  def calculate_growth_rate(start_date, end_date, symbol):
78
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
79
+ if hist is None or hist.empty or 'Close' not in hist.columns:
80
  return None
81
+ beginning_value = hist.iloc[0]['Close'] # Use 'Close' as yfinance with auto_adjust=True returns adjusted prices
82
+ ending_value = hist.iloc[-1]['Close']
83
  years = (end_date - start_date).days / 365.25
84
  if years <= 0:
85
  return 0
 
124
  "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
125
  "For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
126
  "Provide accurate, concise advice based on data."
 
127
  )
128
+ # Placeholder for generation logic (tokenize, generate, decode)
129
+ return summary or "Please provide a specific stock or investment query."
130
 
131
+ # Gradio interface setup
132
+ demo = gr.Interface(
133
+ fn=generate_response,
134
+ inputs=[gr.Textbox(lines=2, placeholder="Enter your query (e.g., 'TSLA CAGR between 2010 and 2020')"), gr.Checkbox(label="Enable Thinking")],
135
+ outputs="text",
136
+ title="FinChat",
137
+ description="Ask about stock performance, CAGR, or investments."
138
+ )
139
+ demo.launch()