ahmedelbeshry commited on
Commit
7e02cc7
1 Parent(s): c2f18c6

Upload 23 files

Browse files
Files changed (23) hide show
  1. .gitattributes +35 -35
  2. .gitignore +0 -0
  3. Dockerfile +26 -0
  4. README.md +11 -11
  5. __init__.py +667 -0
  6. app.log +0 -0
  7. app.py +544 -0
  8. app2.py +560 -0
  9. bm25retriever.pkl +3 -0
  10. chain.py +28 -0
  11. chat.py +667 -0
  12. chatflask.py +646 -0
  13. config.py +18 -0
  14. embeddings.py +62 -0
  15. flasktest.py +49 -0
  16. index.html +70 -0
  17. llm.py +45 -0
  18. logging_config.py +38 -0
  19. main.py +100 -0
  20. rag.py +114 -0
  21. requirements.txt +30 -0
  22. retriever.py +53 -0
  23. tools.py +188 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
Binary file (38 Bytes). View file
 
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as a parent image
FROM python:3.11-slim

# Set environment variables to avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive

# Set the working directory in the container
WORKDIR /app

# Copy the current directory contents into the container at /app
COPY . /app

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the .env file into the image so the app can read its settings.
# NOTE(review): baking secrets into an image layer is a security risk --
# prefer supplying them at runtime (e.g. `docker run --env-file .env`).
COPY .env .env

# Expose the port the Flask app runs on
EXPOSE 5000

# Expose the port the Streamlit app runs on
EXPOSE 8501

# Start the Streamlit app. NOTE(review): despite exposing both ports above,
# this CMD launches ONLY Streamlit; the Flask app is never started here.
CMD ["streamlit", "run", "app.py"]
README.md CHANGED
@@ -1,11 +1,11 @@
1
- ---
2
- title: Financial Chatbot
3
- emoji: 🚀
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Financial Chatbot
3
+ emoji: 🚀
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import yfinance as yf
4
+ import pandas as pd
5
+ from datetime import datetime, timedelta
6
+ import logging
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from config import Config
10
+ import numpy as np
11
+ from typing import Optional, Tuple, List, Dict
12
+ from rag import get_answer
13
+ import time
14
+ from tenacity import retry, stop_after_attempt, wait_exponential
15
+
16
+ # Set up logging
17
+ logging.basicConfig(level=logging.DEBUG,
18
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
19
+ handlers=[logging.FileHandler("app.log"),
20
+ logging.StreamHandler()])
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Initialize the Gemini model
25
+ llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)
26
+
27
+ # Configuration for Google Custom Search API
28
+ GOOGLE_API_KEY = Config.GOOGLE_API_KEY
29
+ SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
30
+
31
+
32
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    """Send *prompt* to the shared Gemini model.

    Transient failures are retried up to 5 times with exponential back-off
    (2-8 s); the final exception is re-raised if every attempt fails.
    """
    response = llm.invoke(prompt)
    return response
35
+
36
+
37
class DataSummarizer:
    """Helpers that gather web/search context and compute technical indicators.

    The ``calculate_*`` methods expect an OHLCV DataFrame with lowercase
    columns (``date``, ``open``, ``high``, ``low``, ``close``, ``volume``).
    Every method logs its runtime and returns ``None`` on failure instead of
    raising, so callers must handle ``None`` results.

    Fix over the original: both outbound HTTP calls now carry a timeout —
    ``requests.get`` without one can block the caller indefinitely on a
    stalled connection.
    """

    # Upper bound (seconds) for outbound HTTP calls.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[str]:
        """Query the Google Custom Search API and return an LLM-written summary.

        Returns None when the HTTP request or the summarization fails.
        """
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            # Fix: bounded timeout so a stuck request cannot hang the caller.
            response = requests.get(url, params=params, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize the search results using Gemini
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        """Join a search-result item's title and snippet into one text block."""
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Simple moving average of the close price (NaN for the first window-1 rows)."""
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Relative Strength Index from simple rolling means of gains/losses."""
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Exponential moving average of the close price."""
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        """Bollinger Bands: rolling mean +/- 2 rolling standard deviations.

        Returns a DataFrame with columns 'MA', 'Upper Band', 'Lower Band'.
        """
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> \
            Optional[pd.DataFrame]:
        """MACD (short EMA - long EMA) and its signal line.

        Returns a DataFrame with columns 'MACD' and 'Signal Line'.
        """
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Rolling volatility of log returns, scaled by sqrt(window).

        NOTE(review): annualized volatility conventionally scales by
        sqrt(252); sqrt(window) is kept here to preserve existing behavior —
        confirm which was intended.
        """
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Average True Range: rolling mean of the true range."""
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """On-Balance Volume: cumulative signed volume by close-price direction."""
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Per-year mean/max/min close and total volume.

        NOTE: adds a 'year' column to *df* in place (visible to the caller).
        """
        start_time = time.time()
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            # Flatten the MultiIndex columns to e.g. 'close_mean', 'volume_sum'.
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Rows from Jan 1 to Dec 31 of the previous calendar year.

        Assumes df['date'] holds datetime.date values (as produced by
        fetch_stock_data_yahoo) — TODO confirm for other callers.
        """
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        """Percent change from the first open to the last close of the current year.

        Returns None (via the except path) when there are no rows this year.
        """
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        """Price/earnings ratio; None when eps is zero (division undefined)."""
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Scrape the answer snippet from a Google results page.

        Fragile by nature: depends on Google's current CSS class names.
        Returns "Snippet not found." when no known snippet container matches,
        and None on request/parse errors.
        """
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            # Fix: bounded timeout so a stuck request cannot hang the caller.
            response = requests.get(search_url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
252
+
253
+
254
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Pull a ticker symbol out of an LLM reply.

    Replies shaped like "... is **AAPL**." yield "AAPL"; any other reply is
    returned whole with surrounding whitespace stripped. Returns None if an
    unexpected error occurs.
    """
    start_time = time.time()
    try:
        has_bold_marker = "is **" in response and "**." in response
        if has_bold_marker:
            ticker = response.split("is **")[1].split("**.")[0].strip()
        else:
            ticker = response.strip()
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
        return ticker
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None
267
+
268
+
269
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run the four-step LLM pipeline over a user query.

    Steps: detect the language, translate to English when needed, extract the
    company entity, then resolve its stock ticker. Returns a 4-tuple
    (language, entity, translated_query, ticker); later elements are None as
    soon as a step fails, and (None, None, None, None) on unexpected errors.
    """
    try:
        start_time = time.time()

        # Step 1: Detect Language
        detected_language = invoke_llm(f"Detect the language for the following text: {query}").content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary)
        translated_query = query
        if detected_language != "English":
            translated_query = invoke_llm(f"Translate the following text to English: {query}").content.strip()
            logger.info(f"Translation completed: {translated_query}")
            print(f"Translation: {translated_query}")

        # Step 3: Detect Entity
        entity_reply = invoke_llm(f"Detect the entity in the following text that is a company name: {translated_query}")
        detected_entity = entity_reply.content.strip()
        logger.info(f"Entity detected: {detected_entity}")
        print(f"Entity: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        ticker_reply = invoke_llm(f"What is the stock ticker symbol for the company {detected_entity}?")
        stock_ticker = extract_ticker_from_response(ticker_reply.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
313
+
314
+
315
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download ~3 years of daily OHLCV history for *symbol* via yfinance.

    Columns are normalized to lowercase and the index flattened into a
    'date' column of datetime.date values. Returns an empty DataFrame on
    any failure.
    """
    start_time = time.time()
    try:
        ticker = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        history_df = ticker.history(start=start_date, end=end_date)
        if history_df.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        rename_map = {"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        history_df = history_df.rename(columns=rename_map)

        # Turn the DatetimeIndex into a plain 'date' column of date objects.
        history_df.reset_index(inplace=True)
        history_df['date'] = history_df['Date'].dt.date
        history_df = history_df.drop(columns=['Date'])
        history_df = history_df[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in history_df.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - start_time)
        return history_df
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
345
+
346
+
347
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Look up the live price for *symbol* via yfinance; None on failure
    (including when 'currentPrice' is absent from the info dict)."""
    start_time = time.time()
    try:
        price = yf.Ticker(symbol).info['currentPrice']
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return price
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
357
+
358
+
359
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render an OHLCV DataFrame as the plain-text table used in LLM prompts.

    Fix: rows are collected in a list and joined once — repeated string
    concatenation in the original was O(n^2) over three years of daily rows.
    Returns "No historical data available." for an empty frame and
    "Error formatting stock data." on unexpected errors.
    """
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        header = ("Historical stock data for the last three years:\n\n"
                  "Date | Open | High | Low | Close | Volume\n"
                  "------------------------------------------------------\n")
        rows = [
            f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"
            for _, row in stock_data.iterrows()
        ]
        formatted_data = header + "".join(rows)

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
377
+
378
+
379
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch company profile fields for *symbol* from Yahoo Finance.

    Returns a dict with friendly keys (name, sector, ...) where missing
    fields default to "N/A", or {"error": ...} on invalid input / failure.
    """
    start_time = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        info = yf.Ticker(symbol).info
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - start_time)

        # Output key -> yfinance info key.
        field_map = {
            "name": "longName",
            "sector": "sector",
            "industry": "industry",
            "marketCap": "marketCap",
            "summary": "longBusinessSummary",
            "website": "website",
            "address": "address1",
            "city": "city",
            "state": "state",
            "country": "country",
            "phone": "phone",
        }
        return {out_key: info.get(src_key, "N/A") for out_key, src_key in field_map.items()}
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}
404
+
405
+
406
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Render a company-info dict as the labelled text block used in prompts.

    Propagates an upstream {"error": ...} dict as an error string; missing
    keys fall into the except path and yield "Error formatting company info.".
    """
    start_time = time.time()
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        parts = [
            "\nCompany Information:\n",
            f"Name: {company_info['name']}\n",
            f"Sector: {company_info['sector']}\n",
            f"Industry: {company_info['industry']}\n",
            f"Market Cap: {company_info['marketCap']}\n",
            f"Summary: {company_info['summary']}\n",
            f"Website: {company_info['website']}\n",
            f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n",
            f"Phone: {company_info['phone']}\n",
        ]
        formatted_info = "".join(parts)

        logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."
427
+
428
+
429
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Return the latest Yahoo Finance news items for *symbol*.

    An empty news feed is treated as a failure; every failure path returns [].
    """
    start_time = time.time()
    try:
        articles = yf.Ticker(symbol).news
        if not articles:
            raise ValueError(f"No news found for symbol: {symbol}")
        logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - start_time)
        return articles
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []
441
+
442
+
443
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render a list of news-article dicts as the text block used in prompts.

    Each article must provide title/publisher/link/providerPublishTime; a
    missing key routes through the except path and yields the error string.
    """
    start_time = time.time()
    try:
        if not news:
            return "No news available."

        article_blocks = [
            f"Title: {article['title']}\n"
            f"Publisher: {article['publisher']}\n"
            f"Link: {article['link']}\n"
            f"Published: {article['providerPublishTime']}\n\n"
            for article in news
        ]
        formatted_news = "Latest company news:\n\n" + "".join(article_blocks)

        logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
461
+
462
+
463
def send_to_gemini_for_summarization(content: str) -> str:
    """Ask the LLM to summarize *content*.

    Bug fix: the original ran ``" ".join(content)`` unconditionally, which
    splays a plain string into space-separated characters ("abc" -> "a b c")
    even though the parameter is annotated ``str``. Joining is now applied
    only to non-string iterables, so both str and list-of-str callers get a
    sensible prompt.
    """
    start_time = time.time()
    try:
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
474
+
475
+
476
def answer_question_with_data(question: str, data: Dict) -> str:
    """Build the financial-advisor prompt from *data* and ask the LLM *question*.

    Returns the stripped model reply, or "Error answering question." on failure.
    """
    start_time = time.time()
    try:
        data_str = "".join(f"{key}:\n{value}\n\n" for key, value in data.items())

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        logger.info("answer_question_with_data took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."
494
+
495
+
496
def calculate_metrics(stock_data: pd.DataFrame, summarizer: 'DataSummarizer', company_info: Dict) -> Dict[str, str]:
    """Compute every technical indicator and render each as a display string.

    Fixes over the original:
    - The two near-identical dict literals (with / without "P/E Ratio") are
      collapsed into one base dict plus a conditional insertion.
    - When ``calculate_pe_ratio`` returns None (e.g. bad inputs), the P/E key
      is simply omitted; previously ``f"{None:.2f}"`` raised and discarded
      ALL metrics via the except path.

    Returns {"Error": ...} only when a core indicator genuinely fails
    (e.g. a calculate_* helper returned None and .to_string() raised).
    """
    start_time = time.time()
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)

        formatted_metrics = {
            "Moving Average": moving_average.to_string(),
            "RSI": rsi.to_string(),
            "EMA": ema.to_string(),
            "Bollinger Bands": bollinger_bands.to_string(),
            "MACD": macd.to_string(),
            "Volatility": volatility.to_string(),
            "ATR": atr.to_string(),
            "OBV": obv.to_string(),
            "Yearly Summary": yearly_summary.to_string(),
            "YTD Performance": f"{ytd_performance:.2f}%",
        }

        # P/E only when earnings-per-share is available and non-zero
        # (falsy eps — None or 0 — is skipped, matching the original).
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            if pe_ratio is not None:
                formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
546
+
547
+
548
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> \
        Dict[str, str]:
    """Assemble every data source into the flat dict fed to the LLM.

    The metrics dict is stored whole under "Calculations" AND flattened into
    the top level, so each metric is addressable by its own name.
    """
    start_time = time.time()
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics,
        **formatted_metrics,
    }
    logger.info("prepare_data took %.2f seconds", time.time() - start_time)
    return collected_data
564
+
565
+
566
def main():
    """Interactive console loop for the financial chatbot.

    Each turn: detect language/entity/ticker from the user's input, fan the
    independent data fetches out to a thread pool, fold everything into one
    dict, and ask the LLM to answer with that context. Conversation history
    is replayed into every prompt. Type exit/quit/bye to stop.
    """
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")

    summarizer = DataSummarizer()
    conversation_history = []

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break

        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

            logger.info(
                f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")

            if entity and stock_ticker:
                # Ticker found: fetch market data, company profile, news, RAG
                # answer and Google context concurrently (all blocking I/O).
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                stock_data = results.get("stock_data", pd.DataFrame())
                formatted_stock_data = format_stock_data_for_gemini(
                    stock_data) if not stock_data.empty else "No historical data available."

                company_info = results.get("company_info", {})
                formatted_company_info = format_company_info_for_gemini(
                    company_info) if company_info else "No company info available."

                company_news = results.get("company_news", [])
                formatted_company_news = format_company_news_for_gemini(
                    company_news) if company_news else "No news available."

                current_stock_price = results.get("current_stock_price", None)

                formatted_metrics = calculate_metrics(stock_data, summarizer,
                                                      company_info) if not stock_data.empty else {
                    "Error": "No stock data for metrics"}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")

                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              google_results, formatted_metrics, google_snippet, rag_response)
                collected_data[
                    "Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"

                # Replay the whole conversation so the LLM keeps context;
                # the translated query is used when a ticker was resolved.
                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {translation}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

            else:
                # No ticker: answer from RAG + Google context only.
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")
                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

        except Exception as e:
            # Catch-all keeps the REPL alive across any per-turn failure.
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")

if __name__ == "__main__":
    main()
app.log ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import requests
3
+ from bs4 import BeautifulSoup
4
+ import yfinance as yf
5
+ import pandas as pd
6
+ from datetime import datetime, timedelta
7
+ import logging
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from config import Config
11
+ import numpy as np
12
+ from typing import Optional, Tuple, List, Dict
13
+ from rag import get_answer
14
+ import time
15
+ from tenacity import retry, stop_after_attempt, wait_exponential
16
+ import threading
17
+ import streamlit as st
18
+ import json
19
+
20
+ # Initialize Flask app
21
+ app = Flask(__name__)
22
+
23
+ # Set up logging
24
+ logging.basicConfig(level=logging.DEBUG,
25
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
26
+ handlers=[logging.FileHandler("app.log"),
27
+ logging.StreamHandler()])
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Initialize the Gemini model
32
+ llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)
33
+
34
+ # Configuration for Google Custom Search API
35
+ GOOGLE_API_KEY = Config.GOOGLE_API_KEY
36
+ SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
37
+
38
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    """Send *prompt* to the shared Gemini client, retrying up to 5 times with
    exponential backoff (2-8s between attempts); the final failure is re-raised
    to the caller because of ``reraise=True``."""
    return llm.invoke(prompt)
41
+
42
class DataSummarizer:
    """Stateless helper bundling Google-search summarization with the
    technical indicators computed over an OHLCV DataFrame.

    All ``calculate_*`` methods expect the lowercase columns produced by
    ``fetch_stock_data_yahoo`` ('date', 'open', 'high', 'low', 'close',
    'volume') and return ``None`` after logging on any failure.
    """

    def google_search(self, query: str) -> Optional[str]:
        """Run *query* through the Google Custom Search API and return an
        LLM-written summary of the result titles/snippets (None on failure)."""
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        """Return "title\\nsnippet" for a single search-result item."""
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Simple moving average of the close over *window* rows."""
        try:
            result = df['close'].rolling(window=window).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Relative Strength Index (0-100) from simple rolling means of
        gains and losses over *window* rows."""
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Exponential moving average of the close with span *window*."""
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        """Bollinger Bands: rolling mean +/- two rolling standard deviations."""
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> Optional[pd.DataFrame]:
        """MACD line (short/long EMA difference) plus its signal-line EMA."""
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Rolling standard deviation of log returns scaled by sqrt(window)."""
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Average True Range: rolling mean of the true range over *window* rows."""
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """On-Balance Volume: cumulative volume signed by each close's direction."""
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Per-year mean/max/min close and total volume.

        NOTE: mutates the caller's frame by adding a 'year' column.
        """
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            # Flatten the two-level ('close','mean') columns to 'close_mean' etc.
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Rows spanning the whole previous calendar year.

        Assumes 'date' holds ``datetime.date`` values (as produced by
        fetch_stock_data_yahoo) so the comparisons are like-typed.
        """
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        """Percent change from the first open to the last close of the
        current year (IndexError on an empty YTD slice is logged → None)."""
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        """Price/earnings ratio; a zero EPS is rejected (logged, returns None)."""
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Scrape the Google results page for *query* and return the first
        answer-box snippet among a set of known CSS classes, the literal
        "Snippet not found." when none match, or None on a request error."""
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            # CSS classes Google has used for answer boxes / knowledge panels;
            # brittle by nature — NOTE(review): verify these still match.
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
225
+
226
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Pull a ticker symbol out of an LLM reply.

    Replies shaped like "... is **TSLA**. ..." yield the bolded symbol;
    any other reply comes back stripped of surrounding whitespace.
    Returns None only if extraction raises unexpectedly.
    """
    marker_open, marker_close = "is **", "**."
    try:
        if marker_open not in response or marker_close not in response:
            return response.strip()
        bolded = response.split(marker_open)[1]
        return bolded.split(marker_close)[0].strip()
    except Exception as exc:
        logger.error(f"Error extracting ticker from response: {exc}")
        return None
234
+
235
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run the four-stage LLM pipeline: language detection → translation to
    English → company-entity detection → ticker lookup.

    Returns ``(detected_language, detected_entity, translated_query,
    stock_ticker)``. Later elements are None as soon as a stage finds nothing;
    all four are None on an unexpected error.
    """
    try:
        # Step 1: Detect Language
        prompt = f"Detect the language for the following text: {query}"
        response = invoke_llm(prompt)
        detected_language = response.content.strip()

        # Step 2: Translate to English (if necessary)
        # NOTE(review): assumes the model replies with the bare word "English";
        # any extra wording forces an unnecessary translation round-trip.
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = invoke_llm(prompt)
            translated_query = response.content.strip()

        # Step 3: Detect Entity
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = invoke_llm(prompt)
        detected_entity = response.content.strip()

        if not detected_entity:
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = invoke_llm(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            return detected_language, detected_entity, translated_query, None

        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
269
+
270
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download roughly three years of daily OHLCV history for *symbol*.

    Columns are normalized to lowercase and the index is flattened into a
    'date' column of ``datetime.date`` values. Returns an empty DataFrame
    after logging if anything fails.
    """
    try:
        now = datetime.now()
        history = yf.Ticker(symbol).history(start=now - timedelta(days=3 * 365), end=now)
        history = history.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )
        # Move the DatetimeIndex into a plain 'date' column of date objects.
        history.reset_index(inplace=True)
        history['date'] = history['Date'].dt.date
        history = history.drop(columns=['Date'])
        return history[['date', 'open', 'high', 'low', 'close', 'volume']]
    except Exception as exc:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {exc}")
        return pd.DataFrame()
285
+
286
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Return the latest available price for *symbol*, or None.

    Yahoo's ``info`` mapping does not always carry ``currentPrice`` (ETFs,
    indices, some foreign listings), and the original direct indexing raised
    KeyError — losing the price entirely. Probe the usual fields in order of
    freshness before giving up.
    """
    try:
        info = yf.Ticker(symbol).info
        for field in ("currentPrice", "regularMarketPrice", "previousClose"):
            price = info.get(field)
            if price is not None:
                return price
        logger.error(f"Failed to fetch current stock price for {symbol}: no price field in info")
        return None
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
294
+
295
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render OHLCV rows as a plain-text table for the LLM prompt.

    Expects the columns produced by ``fetch_stock_data_yahoo`` ('date',
    'open', 'high', 'low', 'close', 'volume'). Returns a placeholder for an
    empty frame and an error string if formatting fails.
    """
    try:
        if stock_data.empty:
            return "No historical data available."
        # Collect rows in a list and join once: the original's repeated
        # `formatted_data += ...` is quadratic over ~750 daily rows.
        lines = [
            "Historical stock data for the last three years:\n",
            "Date | Open | High | Low | Close | Volume",
            "-" * 54,
        ]
        for row in stock_data.itertuples(index=False):
            lines.append(
                f"{row.date} | {row.open:.2f} | {row.high:.2f} | {row.low:.2f} | {row.close:.2f} | {int(row.volume)}"
            )
        return "\n".join(lines) + "\n"
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
308
+
309
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch a company profile from Yahoo Finance.

    Returns a dict of profile fields, each defaulting to "N/A" when Yahoo
    omits it, or ``{"error": message}`` if the lookup fails.
    """
    # Output key -> key inside yfinance's `info` mapping.
    field_map = {
        "name": "longName",
        "sector": "sector",
        "industry": "industry",
        "marketCap": "marketCap",
        "summary": "longBusinessSummary",
        "website": "website",
        "address": "address1",
        "city": "city",
        "state": "state",
        "country": "country",
        "phone": "phone",
    }
    try:
        info = yf.Ticker(symbol).info
        return {out_key: info.get(src_key, "N/A") for out_key, src_key in field_map.items()}
    except Exception as exc:
        logger.error(f"Error fetching company info for {symbol}: {exc}")
        return {"error": str(exc)}
329
+
330
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Turn a company-profile dict into a readable text block for the prompt.

    A dict carrying an 'error' key produces an error message instead; any
    missing profile field (KeyError) yields the generic failure string.
    """
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"
        segments = [
            "",
            "Company Information:",
            f"Name: {company_info['name']}",
            f"Sector: {company_info['sector']}",
            f"Industry: {company_info['industry']}",
            f"Market Cap: {company_info['marketCap']}",
            f"Summary: {company_info['summary']}",
            f"Website: {company_info['website']}",
            f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}",
            f"Phone: {company_info['phone']}",
            "",
        ]
        # Leading/trailing "" entries reproduce the original's surrounding newlines.
        return "\n".join(segments)
    except Exception as exc:
        logger.error(f"Error formatting company info for Gemini: {exc}")
        return "Error formatting company info."
347
+
348
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Return Yahoo Finance's latest news items for *symbol*.

    An empty list is returned both when Yahoo has no news and when the
    request fails (the failure is logged).
    """
    try:
        articles = yf.Ticker(symbol).news
        if not articles:
            return []
        return articles
    except Exception as exc:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {exc}")
        return []
356
+
357
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render news articles as a text block for the LLM prompt.

    Returns "No news available." for an empty list. Yahoo articles sometimes
    omit fields; the original's direct ``article['title']`` indexing let a
    single malformed article abort the whole listing with a KeyError, so each
    field now falls back to "N/A". Rows are joined once instead of built by
    repeated string concatenation.
    """
    try:
        if not news:
            return "No news available."
        blocks = ["Latest company news:\n\n"]
        for article in news:
            blocks.append(f"Title: {article.get('title', 'N/A')}\n"
                          f"Publisher: {article.get('publisher', 'N/A')}\n"
                          f"Link: {article.get('link', 'N/A')}\n"
                          f"Published: {article.get('providerPublishTime', 'N/A')}\n\n")
        return "".join(blocks)
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
371
+
372
def send_to_gemini_for_summarization(content) -> str:
    """Ask Gemini to summarize *content* and return the summary text.

    Accepts a single string or an iterable of strings. The original signature
    claimed ``content: str`` yet unconditionally ran ``" ".join(content)``,
    which splices a space between every *character* of a plain string; join
    only when a non-string iterable is supplied.
    Returns "Error summarizing content." on failure.
    """
    try:
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
381
+
382
def answer_question_with_data(question: str, data: Dict) -> str:
    """Compose a financial-advisor prompt from *question* plus every
    collected data section and return Gemini's answer.

    Returns "Error answering question." if the LLM call fails.
    """
    try:
        sections = [f"{key}:\n{value}\n\n" for key, value in data.items()]
        data_str = "".join(sections)
        prompt = (f"You are a financial advisor. Begin your answer and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        return response.content.strip()
    except Exception as exc:
        logger.error(f"Error answering question with data: {exc}")
        return "Error answering question."
397
+
398
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    """Compute the technical indicators and render each one to text.

    Returns a mapping of metric name -> human-readable string. On total
    failure returns ``{"Error": ...}`` so callers can still merge the result.

    Fixes over the original: (1) the two near-identical dict literals are
    collapsed; (2) every ``calculate_*`` helper returns None on failure and
    the original called ``.to_string()`` / ``:.2f`` formatting on it
    unconditionally, so a single failed indicator wiped out ALL metrics —
    None results are now simply omitted.
    """
    try:
        raw_metrics = {
            "Moving Average": summarizer.calculate_moving_average(stock_data),
            "RSI": summarizer.calculate_rsi(stock_data),
            "EMA": summarizer.calculate_ema(stock_data),
            "Bollinger Bands": summarizer.calculate_bollinger_bands(stock_data),
            "MACD": summarizer.calculate_macd(stock_data),
            "Volatility": summarizer.calculate_volatility(stock_data),
            "ATR": summarizer.calculate_atr(stock_data),
            "OBV": summarizer.calculate_obv(stock_data),
            "Yearly Summary": summarizer.calculate_yearly_summary(stock_data),
        }
        formatted_metrics = {name: value.to_string() for name, value in raw_metrics.items() if value is not None}

        ytd_performance = summarizer.calculate_ytd_performance(stock_data)
        if ytd_performance is not None:
            formatted_metrics["YTD Performance"] = f"{ytd_performance:.2f}%"

        # P/E only when the company reports a usable (non-zero) trailing EPS.
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            if pe_ratio is not None:
                formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
444
+
445
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> Dict[str, str]:
    """Assemble every data source into the single dict fed to the LLM.

    The metrics dict is kept whole under "Calculations" and additionally
    flattened into the top level, so each metric is reachable by name.
    """
    merged = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics,
        **formatted_metrics,
    }
    return merged
458
+
459
@app.route('/ask', methods=['POST'])
def ask():
    """POST /ask — answer a financial question from JSON body {"question": ...}.

    Pipeline: the LLM detects language / company / ticker. When a ticker is
    found, Yahoo Finance data, news, current price, Google search, snippet
    scraping and the RAG lookup all run concurrently and the combined data is
    handed to the LLM; otherwise only the search + RAG sources are used.
    Responds with {"answer": ...} or a 500 {"error": ...} payload.
    """
    try:
        user_input = request.json.get('question')
        summarizer = DataSummarizer()
        # NOTE(review): `language` is detected but never used below.
        language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)
        if entity and stock_ticker:
            # Company identified: fan out all data fetches in parallel.
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                    executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                    executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                    executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                # Map each completed future back to its label.
                results = {futures[future]: future.result() for future in as_completed(futures)}
            stock_data = results.get("stock_data", pd.DataFrame())
            formatted_stock_data = format_stock_data_for_gemini(stock_data) if not stock_data.empty else "No historical data available."
            company_info = results.get("company_info", {})
            formatted_company_info = format_company_info_for_gemini(company_info) if company_info else "No company info available."
            company_news = results.get("company_news", [])
            formatted_company_news = format_company_news_for_gemini(company_news) if company_news else "No news available."
            current_stock_price = results.get("current_stock_price", None)
            formatted_metrics = calculate_metrics(stock_data, summarizer, company_info) if not stock_data.empty else {"Error": "No stock data for metrics"}
            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")
            collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news, google_results, formatted_metrics, google_snippet, rag_response)
            collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"
            # The translated (English) query is used for the ticker path.
            answer = answer_question_with_data(f"{translation}", collected_data)
            return jsonify({"answer": answer})
        else:
            # No company detected: answer from web search + RAG only.
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}
            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")
            collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)
            answer = answer_question_with_data(f"{user_input}", collected_data)
            return jsonify({"answer": answer})
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return jsonify({"error": "An error occurred while processing your request. Please try again later."}), 500
509
+
510
+ # Streamlit App
511
def send_question_to_api(question):
    """POST *question* to the local Flask /ask endpoint and return the answer
    text, or a short error description for a non-200 response."""
    url = 'http://localhost:5000/ask'
    headers = {'Content-Type': 'application/json'}
    payload = json.dumps({'question': question})
    response = requests.post(url, headers=headers, data=payload)
    if response.status_code != 200:
        return f"Error: {response.status_code} - {response.text}"
    return response.json().get('answer')
520
+
521
def run_streamlit():
    """Minimal Streamlit front end: sends questions to the local Flask /ask
    endpoint via send_question_to_api and keeps a Q/A history across reruns."""
    st.title("Financial Data Chatbot Tester")
    st.write("Enter your question below and get a response from the chatbot.")
    # st.session_state survives Streamlit's script reruns; seed history once.
    if 'history' not in st.session_state:
        st.session_state.history = []
    user_input = st.text_input("Your question:", "")
    if st.button("Submit"):
        if user_input:
            with st.spinner('Getting the answer...'):
                answer = send_question_to_api(user_input)
                st.session_state.history.append((user_input, answer))
                st.success(answer)
        else:
            st.warning("Please enter a question before submitting.")
    # Replay the whole conversation below the input on every rerun.
    if st.session_state.history:
        st.write("### History")
        for idx, (question, answer) in enumerate(st.session_state.history, 1):
            st.write(f"**Q{idx}:** {question}")
            st.write(f"**A{idx}:** {answer}")
            st.write("---")
541
+
542
if __name__ == '__main__':
    # Serve the Flask API on a background thread so the Streamlit UI below can
    # reach http://localhost:5000/ask from the same process.
    # NOTE(review): presumably launched via `streamlit run`; confirm the Flask
    # thread is not started again on Streamlit script reruns.
    threading.Thread(target=lambda: app.run(host='0.0.0.0', port=5000)).start()
    run_streamlit()
app2.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import yfinance as yf
4
+ import pandas as pd
5
+ from datetime import datetime, timedelta
6
+ import logging
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from config import Config
10
+ import numpy as np
11
+ from typing import Optional, Tuple, List, Dict
12
+ from rag import get_answer
13
+
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.DEBUG,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
17
+ handlers=[logging.FileHandler("app.log"),
18
+ logging.StreamHandler()])
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Initialize the Gemini model
23
+ llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)
24
+
25
+ # Configuration for Google Custom Search API
26
+ GOOGLE_API_KEY = Config.GOOGLE_API_KEY
27
+ SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
28
+
29
def fetch_google_snippet(query: str) -> Optional[str]:
    """Scrape google.com search results for *query* and return the text of
    the first answer-box div among the known CSS classes, the literal
    "Snippet not found." when none match, or None on a request/parse error."""
    known_classes = (
        'BNeawe iBp4i AP7Wnd',
        'BNeawe s3v9rd AP7Wnd',
        'BVG0Nb',
        'kno-rdesc',
    )
    try:
        page = requests.get(
            f"https://www.google.com/search?q={query}",
            headers={
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            },
        )
        soup = BeautifulSoup(page.text, 'html.parser')
        for css_class in known_classes:
            hit = soup.find('div', class_=css_class)
            if hit:
                return hit.get_text()
        return "Snippet not found."
    except Exception as exc:
        logger.error(f"Error fetching Google snippet: {exc}")
        return None
51
+
52
+ class DataSummarizer:
53
+ def __init__(self):
54
+ pass
55
+
56
+ def google_search(self, query: str) -> Optional[Dict]:
57
+ try:
58
+ url = "https://www.googleapis.com/customsearch/v1"
59
+ params = {
60
+ 'key': GOOGLE_API_KEY,
61
+ 'cx': SEARCH_ENGINE_ID,
62
+ 'q': query
63
+ }
64
+ response = requests.get(url, params=params)
65
+ response.raise_for_status()
66
+ return response.json()
67
+ except Exception as e:
68
+ logger.error(f"Error during Google Search API request: {e}")
69
+ return None
70
+
71
+ def extract_content_from_item(self, item: Dict) -> Optional[str]:
72
+ try:
73
+ snippet = item.get('snippet', '')
74
+ title = item.get('title', '')
75
+ return f"{title}\n{snippet}"
76
+ except Exception as e:
77
+ logger.error(f"Error extracting content from item: {e}")
78
+ return None
79
+
80
+ def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
81
+ try:
82
+ return df['close'].rolling(window=window).mean()
83
+ except Exception as e:
84
+ logger.error(f"Error calculating moving average: {e}")
85
+ return None
86
+
87
+ def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
88
+ try:
89
+ delta = df['close'].diff()
90
+ gain = delta.where(delta > 0, 0).rolling(window=window).mean()
91
+ loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
92
+ rs = gain / loss
93
+ return 100 - (100 / (1 + rs))
94
+ except Exception as e:
95
+ logger.error(f"Error calculating RSI: {e}")
96
+ return None
97
+
98
+ def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
99
+ try:
100
+ return df['close'].ewm(span=window, adjust=False).mean()
101
+ except Exception as e:
102
+ logger.error(f"Error calculating EMA: {e}")
103
+ return None
104
+
105
+ def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
106
+ try:
107
+ ma = df['close'].rolling(window=window).mean()
108
+ std = df['close'].rolling(window=window).std()
109
+ upper_band = ma + (std * 2)
110
+ lower_band = ma - (std * 2)
111
+ return pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
112
+ except Exception as e:
113
+ logger.error(f"Error calculating Bollinger Bands: {e}")
114
+ return None
115
+
116
+ def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> \
117
+ Optional[pd.DataFrame]:
118
+ try:
119
+ short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
120
+ long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
121
+ macd = short_ema - long_ema
122
+ signal = macd.ewm(span=signal_window, adjust=False).mean()
123
+ return pd.DataFrame({'MACD': macd, 'Signal Line': signal})
124
+ except Exception as e:
125
+ logger.error(f"Error calculating MACD: {e}")
126
+ return None
127
+
128
+ def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
129
+ try:
130
+ log_returns = np.log(df['close'] / df['close'].shift(1))
131
+ return log_returns.rolling(window=window).std() * np.sqrt(window)
132
+ except Exception as e:
133
+ logger.error(f"Error calculating volatility: {e}")
134
+ return None
135
+
136
+ def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
137
+ try:
138
+ high_low = df['high'] - df['low']
139
+ high_close = np.abs(df['high'] - df['close'].shift())
140
+ low_close = np.abs(df['low'] - df['close'].shift())
141
+ true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
142
+ return true_range.rolling(window=window).mean()
143
+ except Exception as e:
144
+ logger.error(f"Error calculating ATR: {e}")
145
+ return None
146
+
147
+ def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
148
+ try:
149
+ return (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
150
+ except Exception as e:
151
+ logger.error(f"Error calculating OBV: {e}")
152
+ return None
153
+
154
+ def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
155
+ try:
156
+ df['year'] = pd.to_datetime(df['date']).dt.year
157
+ yearly_summary = df.groupby('year').agg({
158
+ 'close': ['mean', 'max', 'min'],
159
+ 'volume': 'sum'
160
+ })
161
+ yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
162
+ return yearly_summary
163
+ except Exception as e:
164
+ logger.error(f"Error calculating yearly summary: {e}")
165
+ return None
166
+
167
+ def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
168
+ try:
169
+ today = datetime.today().date()
170
+ last_year_start = datetime(today.year - 1, 1, 1).date()
171
+ last_year_end = datetime(today.year - 1, 12, 31).date()
172
+ mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
173
+ return df.loc[mask]
174
+ except Exception as e:
175
+ logger.error(f"Error filtering data for the last year: {e}")
176
+ return None
177
+
178
+ def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
179
+ try:
180
+ today = datetime.today().date()
181
+ year_start = datetime(today.year, 1, 1).date()
182
+ mask = (df['date'] >= year_start) & (df['date'] <= today)
183
+ ytd_data = df.loc[mask]
184
+ opening_price = ytd_data.iloc[0]['open']
185
+ closing_price = ytd_data.iloc[-1]['close']
186
+ return ((closing_price - opening_price) / opening_price) * 100
187
+ except Exception as e:
188
+ logger.error(f"Error calculating YTD performance: {e}")
189
+ return None
190
+
191
+ def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
192
+ try:
193
+ if eps == 0:
194
+ raise ValueError("EPS cannot be zero for P/E ratio calculation.")
195
+ return current_price / eps
196
+ except Exception as e:
197
+ logger.error(f"Error calculating P/E ratio: {e}")
198
+ return None
199
+
200
    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Fetch a Google answer-box snippet for `query`.

        Thin wrapper that delegates to a module-level `fetch_google_snippet`
        helper; returns None when that helper raises.
        NOTE(review): the module-level helper is not visible in this file
        excerpt — if none exists, the call raises NameError (caught below)
        and this always returns None. Confirm the definition/import.
        """
        try:
            return fetch_google_snippet(query)
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
206
+
207
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Pull a ticker out of an LLM answer shaped like '... is **TICK**. ...'.

    Falls back to the stripped raw response when the bold markers are
    absent; returns None only if parsing itself raises.
    """
    try:
        has_bold_ticker = "is **" in response and "**." in response
        if not has_bold_ticker:
            return response.strip()
        return response.split("is **")[1].split("**.")[0].strip()
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None
215
+
216
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run a four-step LLM pipeline over the raw user query.

    Returns (detected_language, detected_entity, translated_query,
    stock_ticker); all four are None when any step raises.
    """
    try:
        # Step 1: language identification.
        prompt = f"Detect the language for the following text: {query}"
        response = llm.invoke(prompt)
        detected_language = response.content.strip()

        # Step 2: translate to English only when necessary.
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = llm.invoke(prompt)
            translated_query = response.content.strip()

        # Step 3: company-name extraction from the (translated) text.
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = llm.invoke(prompt)
        detected_entity = response.content.strip()

        # Step 4: map the company name to a ticker symbol.
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = llm.invoke(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
240
+
241
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download ~3 years of daily OHLCV for `symbol` from Yahoo Finance.

    Returns a DataFrame with columns date/open/high/low/close/volume,
    or an empty DataFrame on any failure (logged).
    """
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        # Normalize Yahoo's capitalized column names to the lowercase schema
        # used by the indicator helpers.
        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )

        # Move the DatetimeIndex into a plain date-only `date` column.
        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
269
+
270
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Return Yahoo Finance's live 'currentPrice' for `symbol`, or None.

    A missing 'currentPrice' key raises KeyError, which is caught and
    logged like any other failure.
    """
    try:
        stock = yf.Ticker(symbol)
        return stock.info['currentPrice']
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
277
+
278
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render the OHLCV frame as a plain-text table for the LLM prompt.

    Fix: builds the table with a single join instead of repeated string
    concatenation (quadratic over thousands of daily rows), and iterates
    with itertuples rather than the much slower iterrows. Output is
    byte-identical to the original formatting.
    """
    try:
        if stock_data.empty:
            return "No historical data available."

        header = ("Historical stock data for the last three years:\n\n"
                  "Date | Open | High | Low | Close | Volume\n"
                  "------------------------------------------------------\n")
        rows = [
            f"{row.date} | {row.open:.2f} | {row.high:.2f} | {row.low:.2f} | {row.close:.2f} | {int(row.volume)}\n"
            for row in stock_data.itertuples(index=False)
        ]
        return header + "".join(rows)
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
294
+
295
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch company profile fields from Yahoo Finance.

    Returns a dict of name/sector/industry/marketCap/summary/contact
    fields ('N/A' for anything missing), or {"error": ...} on failure.
    """
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        stock = yf.Ticker(symbol)
        company_info = stock.info
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}
318
+
319
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Render the company-info dict as a labelled text block for the prompt.

    An {"error": ...} dict is reported verbatim; formatting failures fall
    back to a generic error string.
    """
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        lines = [
            "\nCompany Information:",
            f"Name: {company_info['name']}",
            f"Sector: {company_info['sector']}",
            f"Industry: {company_info['industry']}",
            f"Market Cap: {company_info['marketCap']}",
            f"Summary: {company_info['summary']}",
            f"Website: {company_info['website']}",
            f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}",
            f"Phone: {company_info['phone']}",
        ]
        return "\n".join(lines) + "\n"
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."
338
+
339
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Return Yahoo Finance's latest news items for `symbol` ([] on failure).

    An empty feed is deliberately raised and logged so the caller always
    receives a consistent empty list.
    """
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        if not news:
            raise ValueError(f"No news found for symbol: {symbol}")
        return news
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []
349
+
350
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render a list of Yahoo Finance news items as prompt text.

    Fixes: a missing field no longer aborts formatting — news items from
    yfinance do not always carry every key, and the original raised
    KeyError on the first incomplete article (collapsing everything to the
    generic error string). Also builds the result with join instead of
    quadratic string concatenation.
    """
    try:
        if not news:
            return "No news available."

        blocks = [
            (f"Title: {article.get('title', 'N/A')}\n"
             f"Publisher: {article.get('publisher', 'N/A')}\n"
             f"Link: {article.get('link', 'N/A')}\n"
             f"Published: {article.get('providerPublishTime', 'N/A')}\n\n")
            for article in news
        ]
        return "Latest company news:\n\n" + "".join(blocks)
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
366
+
367
def send_to_gemini_for_summarization(content: str) -> str:
    """Ask the LLM to summarize `content`; returns the stripped reply text.

    Bug fix: the original did `" ".join(content)` on a string argument
    (the caller passes a formatted string), which interleaves a space
    between every single character of the prompt. Joining is now applied
    only when a list/tuple of fragments is supplied.
    """
    try:
        if isinstance(content, (list, tuple)):
            unified_content = " ".join(content)
        else:
            unified_content = content
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
376
+
377
def answer_question_with_data(question: str, data: Dict) -> str:
    """Answer `question` with the LLM, grounding it in the collected data.

    `data` maps section names to preformatted text; everything is inlined
    into a single "financial advisor" prompt. Returns a generic error
    string if the model call fails.
    """
    try:
        # Flatten every data section into "key:\nvalue" blocks for the prompt.
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."
393
+
394
def format_google_results(google_results: Optional[Dict], summarizer: DataSummarizer, query: str) -> str:
    """Combine Custom Search items and a scraped answer-box snippet into text.

    Two snippet queries are tried; the plain query is the fallback when the
    "answer only" variant returns nothing useful.
    """
    try:
        if google_results:
            google_content = [summarizer.extract_content_from_item(item) for item in google_results.get('items', [])]
            formatted_google_content = "\n\n".join(google_content)
        else:
            formatted_google_content = "No additional news found through Google Search."

        snippet_query1 = f"{query} I want the answer only"
        snippet_query2 = f"{query}"

        google_snippet1 = summarizer.fetch_google_snippet(snippet_query1)
        google_snippet2 = summarizer.fetch_google_snippet(snippet_query2)

        # Prefer the targeted query's snippet when it actually found one.
        google_snippet = google_snippet1 if google_snippet1 and google_snippet1 != "Snippet not found." else google_snippet2
        formatted_google_content += f"\n\nGoogle Snippet: {google_snippet}"

        return formatted_google_content
    except Exception as e:
        logger.error(f"Error formatting Google results: {e}")
        return "Error formatting Google results."
415
+
416
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    """Compute every technical indicator and return them as display strings.

    The P/E ratio is appended only when the company reports a truthy
    'trailingEps'; the original duplicated the entire dict literal just to
    add that one key. Any failure (e.g. a metric helper returning None,
    making .to_string() raise) is logged and yields an error dict.
    """
    try:
        formatted_metrics = {
            "Moving Average": summarizer.calculate_moving_average(stock_data).to_string(),
            "RSI": summarizer.calculate_rsi(stock_data).to_string(),
            "EMA": summarizer.calculate_ema(stock_data).to_string(),
            "Bollinger Bands": summarizer.calculate_bollinger_bands(stock_data).to_string(),
            "MACD": summarizer.calculate_macd(stock_data).to_string(),
            "Volatility": summarizer.calculate_volatility(stock_data).to_string(),
            "ATR": summarizer.calculate_atr(stock_data).to_string(),
            "OBV": summarizer.calculate_obv(stock_data).to_string(),
            "Yearly Summary": summarizer.calculate_yearly_summary(stock_data).to_string(),
            "YTD Performance": f"{summarizer.calculate_ytd_performance(stock_data):.2f}%",
        }

        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
464
+
465
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 summarized_google_content: str, formatted_metrics: Dict[str, str]) -> Dict[str, str]:
    """Assemble every formatted data source into the dict fed to the LLM.

    Each metric is flattened in as its own key (unchanged behavior); the
    combined "Calculations" entry is now rendered as text — the original
    stored the raw dict there, violating the declared Dict[str, str]
    contract and duplicating the same data in two shapes.
    """
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": summarized_google_content,
        "Calculations": "\n".join(f"{name}:\n{value}" for name, value in formatted_metrics.items()),
    }
    # Flatten each metric in as its own top-level key as well.
    collected_data.update(formatted_metrics)
    return collected_data
476
+
477
def translate_response(response: str, target_language: str) -> str:
    """Translate `response` into `target_language` via the LLM.

    On any failure the untranslated text is returned so the caller always
    has something to show the user.
    """
    try:
        translation_prompt = f"Translate the following text to {target_language}: {response}"
        result = llm.invoke(translation_prompt)
        return result.content.strip()
    except Exception as e:
        logger.error(f"Error translating response: {e}")
        return response
485
+
486
def main():
    """Interactive console loop for the financial-data chatbot.

    For each user query: detect language/entity/ticker via the LLM, fetch
    all external data sources concurrently, compute indicator strings,
    merge everything (plus a RAG answer) into one context dict, ask the
    LLM for the final answer, and translate it back when the query was not
    in English. Exits on 'exit'/'quit'/'bye'.
    """
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")

    summarizer = DataSummarizer()
    conversation_history = []  # running transcript fed back into each prompt

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break

        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

            if language and entity and translation and stock_ticker:
                # Fan out the five independent network fetches in parallel.
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(summarizer.google_search, f"{user_input} latest financial news"): "google_results"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                stock_data = results["stock_data"]
                formatted_stock_data = format_stock_data_for_gemini(stock_data)
                company_info = results["company_info"]
                formatted_company_info = format_company_info_for_gemini(company_info)
                company_news = results["company_news"]
                formatted_company_news = format_company_news_for_gemini(company_news)
                current_stock_price = results["current_stock_price"]

                google_results = results["google_results"]
                formatted_google_content = format_google_results(google_results, summarizer, user_input)
                summarized_google_content = send_to_gemini_for_summarization(formatted_google_content)

                formatted_metrics = calculate_metrics(stock_data, summarizer, company_info)

                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              summarized_google_content, formatted_metrics)
                collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price else "N/A"

                # Augment with the retrieval-based answer from the RAG pipeline.
                rag_response = get_answer(user_input)
                collected_data["RAG Response"] = rag_response

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)

                # Reply in the user's own language when it wasn't English.
                if language != "English":
                    answer = translate_response(answer, language)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

            else:
                response = "I'm sorry, I couldn't process your request. Could you please rephrase?"
                print(f"Bot: {response}")
                conversation_history.append(f"Bot: {response}")

        except Exception as e:
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")


if __name__ == "__main__":
    main()
bm25retriever.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df285be2ae20135ec5219dd34edf52abe2c630b6372f33f1502e48fd52042526
3
+ size 4215997
chain.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import (
2
+ SystemMessagePromptTemplate,
3
+ HumanMessagePromptTemplate,
4
+ ChatPromptTemplate,
5
+ MessagesPlaceholder
6
+ )
7
+ from langchain.chains import ConversationChain
8
+
9
class Chain:
    """Small convenience wrapper around a LangChain chat model.

    Optionally holds a conversation-memory object used to build a
    history-aware ConversationChain.
    """

    def __init__(self, llm, history=None):
        # NOTE(review): `self.history` is only assigned when `history` is
        # given, so `get_chain_with_history` raises AttributeError on an
        # instance constructed without one — confirm callers always pass it.
        self.llm = llm
        # self.chain = self.get_conversational_chain()
        if history is not None:
            self.history = history

    def run_conversational_chain(self, prompt_template):
        """Send a single prompt to the model and return its text content."""
        ans = self.llm.invoke(prompt_template).content

        return ans

    def get_chain_with_history(self):
        """Build a ConversationChain that answers strictly from provided context."""
        system_msg_template = SystemMessagePromptTemplate.from_template(template="""Answer the question as truthfully as possible using the provided context,
        and if the answer is not contained within the text below, say 'I don't know'""")
        human_msg_template = HumanMessagePromptTemplate.from_template(template="{input}")
        # Prompt layout: system rules, then the running history, then the new input.
        prompt_template = ChatPromptTemplate.from_messages([system_msg_template, MessagesPlaceholder(variable_name="history"), human_msg_template])
        conversation = ConversationChain(memory=self.history, prompt=prompt_template, llm=self.llm, verbose=True)
        return conversation
chat.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import yfinance as yf
4
+ import pandas as pd
5
+ from datetime import datetime, timedelta
6
+ import logging
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from config import Config
10
+ import numpy as np
11
+ from typing import Optional, Tuple, List, Dict
12
+ from rag import get_answer
13
+ import time
14
+ from tenacity import retry, stop_after_attempt, wait_exponential
15
+
16
+ # Set up logging
17
+ logging.basicConfig(level=logging.DEBUG,
18
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
19
+ handlers=[logging.FileHandler("app.log"),
20
+ logging.StreamHandler()])
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Initialize the Gemini model
25
+ llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)
26
+
27
+ # Configuration for Google Custom Search API
28
+ GOOGLE_API_KEY = Config.GOOGLE_API_KEY
29
+ SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
30
+
31
+
32
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    """Call the Gemini model, retrying up to 5 times with 2-8s exponential backoff.

    With reraise=True the last underlying exception propagates after the
    final attempt instead of a tenacity RetryError.
    """
    return llm.invoke(prompt)
35
+
36
+
37
class DataSummarizer:
    """Stateless helper bundling Google-search summarization and the
    technical-indicator calculations used to build the LLM prompt.

    Every method catches its own exceptions, logs them, and returns None
    (or a sentinel string) so one failed metric never aborts the pipeline.
    Most methods also log their wall-clock duration at INFO level.
    """

    def __init__(self):
        # No per-instance state; the class only namespaces the helpers.
        pass

    def google_search(self, query: str) -> Optional[str]:
        """Run a Custom Search query and return a Gemini summary of the hits."""
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize the search results using Gemini
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        """Collapse one search-result item to 'title<newline>snippet'."""
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Simple moving average of 'close' (first window-1 rows are NaN)."""
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Relative Strength Index over `window` periods (0-100 scale)."""
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Exponential moving average of 'close' with span=`window`."""
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        """Bollinger Bands: rolling mean +/- 2 rolling standard deviations."""
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> \
            Optional[pd.DataFrame]:
        """MACD (short EMA minus long EMA) plus its EMA signal line."""
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Rolling std-dev of log returns scaled by sqrt(window)."""
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Average True Range: rolling mean of the per-row true range."""
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """On-Balance Volume: cumulative volume signed by the close move."""
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Per-calendar-year close mean/max/min and total volume.

        NOTE(review): adds a 'year' column to the caller's DataFrame in
        place — consider copying first if that side effect is unwanted.
        """
        start_time = time.time()
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            # Flatten ('close', 'mean') -> 'close_mean' etc.
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Rows dated within the previous calendar year (inclusive)."""
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        """Percent change from the year's first open to the latest close."""
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            # Empty YTD selection makes iloc[0] raise -> logged, returns None.
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        """Price/earnings ratio; zero EPS is rejected (logged, returns None)."""
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Scrape google.com search HTML for an answer-box snippet.

        Tries several known snippet CSS classes; inherently fragile since
        Google's markup changes frequently. Returns the sentinel string
        "Snippet not found." when no known class matches.
        """
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
+ return None
252
+
253
+
254
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Extract a ticker from an LLM answer shaped like '... is **TICK**. ...'.

    Falls back to the stripped raw response when the bold markers are
    absent; timing is logged at INFO level, failures at ERROR.
    """
    start_time = time.time()
    try:
        if "is **" in response and "**." in response:
            ticker = response.split("is **")[1].split("**.")[0].strip()
        else:
            ticker = response.strip()
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
        return ticker
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None
+ return None
267
+
268
+
269
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run the four-step LLM pipeline with early-exit checks.

    Returns (detected_language, detected_entity, translated_query,
    stock_ticker). Unlike the sibling implementation, a missing entity or
    ticker returns the partial results with None in the missing slots;
    only a raised exception yields (None, None, None, None).
    """
    try:
        start_time = time.time()

        # Step 1: Detect Language
        prompt = f"Detect the language for the following text: {query}"
        response = invoke_llm(prompt)
        detected_language = response.content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary)
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = invoke_llm(prompt)
            translated_query = response.content.strip()
            logger.info(f"Translation completed: {translated_query}")
            print(f"Translation: {translated_query}")

        # Step 3: Detect Entity
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = invoke_llm(prompt)
        detected_entity = response.content.strip()
        logger.info(f"Entity detected: {detected_entity}")
        print(f"Entity: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = invoke_llm(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
+ return None, None, None, None
313
+
314
+
315
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download roughly three years of daily OHLCV history for *symbol*.

    Returns a DataFrame with columns ``[date, open, high, low, close, volume]``
    (``date`` as ``datetime.date``), or an empty DataFrame when the download
    fails or yields no rows.
    """
    t0 = time.time()
    try:
        ticker = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        now = datetime.now()
        history = ticker.history(start=now - timedelta(days=3 * 365), end=now)
        if history.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        # Normalize Yahoo's capitalized column names and flatten the index.
        frame = history.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        ).reset_index()
        frame['date'] = frame['Date'].dt.date
        frame = frame.drop(columns=['Date'])[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in frame.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - t0)
        return frame
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
345
+
346
+
347
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Return the latest trading price for *symbol*, or None when unavailable.

    Fix: the original indexed ``info['currentPrice']`` directly, which raises
    KeyError (swallowed into a None) for instruments that do not publish that
    field (ETFs, indices, some foreign listings). Those typically expose
    ``regularMarketPrice`` instead, so that is used as a fallback.
    """
    start_time = time.time()
    try:
        info = yf.Ticker(symbol).info
        price = info.get('currentPrice')
        if price is None:
            # Fallback field used by Yahoo for non-equity instruments.
            price = info.get('regularMarketPrice')
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return price
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
357
+
358
+
359
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render the OHLCV DataFrame as a plain-text table for the LLM prompt.

    Fix: the original appended to a string with ``+=`` inside the row loop,
    which is quadratic over ~750 daily rows; the rows are now collected in a
    list and joined once. The emitted text is byte-identical.
    """
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        lines = [
            "Historical stock data for the last three years:\n",
            "Date | Open | High | Low | Close | Volume",
            "------------------------------------------------------",
        ]
        for _, row in stock_data.iterrows():
            lines.append(
                f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | "
                f"{row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}"
            )
        # Single join keeps this O(n); trailing newline matches the original.
        formatted_data = "\n".join(lines) + "\n"

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
377
+
378
+
379
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch descriptive company metadata for *symbol* from Yahoo Finance.

    Returns a dict of profile fields (each "N/A" when Yahoo omits it), or a
    dict with a single "error" key when the symbol is falsy or the fetch fails.
    """
    t0 = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        info = yf.Ticker(symbol).info
        # Output key -> Yahoo field name; every lookup shares the "N/A" default.
        field_map = {
            "name": "longName",
            "sector": "sector",
            "industry": "industry",
            "marketCap": "marketCap",
            "summary": "longBusinessSummary",
            "website": "website",
            "address": "address1",
            "city": "city",
            "state": "state",
            "country": "country",
            "phone": "phone",
        }
        profile = {out_key: info.get(src_key, "N/A") for out_key, src_key in field_map.items()}
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - t0)
        return profile
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}
404
+
405
+
406
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Render the company-profile dict as a text section for the LLM prompt.

    A dict carrying an "error" key is turned into an error line instead.
    """
    t0 = time.time()
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        lines = [
            "",  # leading blank line, as in the original output
            "Company Information:",
            f"Name: {company_info['name']}",
            f"Sector: {company_info['sector']}",
            f"Industry: {company_info['industry']}",
            f"Market Cap: {company_info['marketCap']}",
            f"Summary: {company_info['summary']}",
            f"Website: {company_info['website']}",
            f"Address: {company_info['address']}, {company_info['city']}, "
            f"{company_info['state']}, {company_info['country']}",
            f"Phone: {company_info['phone']}",
        ]
        formatted_info = "\n".join(lines) + "\n"

        logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - t0)
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."
427
+
428
+
429
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Return the latest Yahoo Finance news items for *symbol*, or [] on failure."""
    t0 = time.time()
    try:
        articles = yf.Ticker(symbol).news
        if not articles:
            # An empty feed is treated like a fetch failure.
            raise ValueError(f"No news found for symbol: {symbol}")
        logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - t0)
        return articles
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []
441
+
442
+
443
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render news articles as a text section for the LLM prompt.

    Fixes: the original indexed ``article['title']`` etc. directly, so one
    article missing a field raised KeyError and discarded the entire listing
    (Yahoo news items do not always carry every key); fields now fall back to
    "N/A". ``providerPublishTime`` — a unix epoch in Yahoo's payload — is now
    rendered as a readable UTC timestamp instead of a raw integer.
    """
    start_time = time.time()
    try:
        if not news:
            return "No news available."

        parts = ["Latest company news:\n"]
        for article in news:
            published = article.get('providerPublishTime', 'N/A')
            if isinstance(published, (int, float)):
                published = datetime.utcfromtimestamp(published).strftime('%Y-%m-%d %H:%M UTC')
            parts.append(
                f"Title: {article.get('title', 'N/A')}\n"
                f"Publisher: {article.get('publisher', 'N/A')}\n"
                f"Link: {article.get('link', 'N/A')}\n"
                f"Published: {published}\n"
            )
        formatted_news = "\n".join(parts) + "\n"

        logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
461
+
462
+
463
def send_to_gemini_for_summarization(content) -> str:
    """Ask the LLM to summarize *content* (a string, or a list of text chunks).

    Fix: the original ran ``" ".join(content)`` unconditionally. Because str
    is itself iterable, passing a plain string (as the old ``content: str``
    annotation suggested) interleaved a space between every single character
    before it reached the model. Strings now pass through untouched; only
    real sequences of chunks are joined.
    """
    start_time = time.time()
    try:
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
474
+
475
+
476
def answer_question_with_data(question: str, data: Dict) -> str:
    """Compose a financial-advisor prompt from the collected *data* and ask the LLM.

    Each entry of *data* is rendered as ``key:\\nvalue`` and embedded in a
    fixed instruction template; returns the model's stripped reply or an
    error placeholder.
    """
    t0 = time.time()
    try:
        data_str = "".join(f"{key}:\n{value}\n\n" for key, value in data.items())

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        logger.info("answer_question_with_data took %.2f seconds", time.time() - t0)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."
494
+
495
+
496
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    """Compute the technical-indicator bundle embedded in the LLM prompt.

    Returns a mapping of metric name -> pre-rendered string. The "P/E Ratio"
    entry is added only when the company reports a non-zero trailing EPS.
    On any failure a single ``{"Error": ...}`` entry is returned.

    Fix: the original duplicated the ten common dict entries verbatim in both
    branches of the EPS check; they are now built once and the P/E entry is
    appended conditionally (same keys, same order, same values).
    """
    start_time = time.time()
    try:
        formatted_metrics = {
            "Moving Average": summarizer.calculate_moving_average(stock_data).to_string(),
            "RSI": summarizer.calculate_rsi(stock_data).to_string(),
            "EMA": summarizer.calculate_ema(stock_data).to_string(),
            "Bollinger Bands": summarizer.calculate_bollinger_bands(stock_data).to_string(),
            "MACD": summarizer.calculate_macd(stock_data).to_string(),
            "Volatility": summarizer.calculate_volatility(stock_data).to_string(),
            "ATR": summarizer.calculate_atr(stock_data).to_string(),
            "OBV": summarizer.calculate_obv(stock_data).to_string(),
            "Yearly Summary": summarizer.calculate_yearly_summary(stock_data).to_string(),
            "YTD Performance": f"{summarizer.calculate_ytd_performance(stock_data):.2f}%",
        }

        eps = company_info.get('trailingEps', None)
        if eps:  # truthiness deliberately skips both missing and zero EPS
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
546
+
547
+
548
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> \
        Dict[str, str]:
    """Assemble every collected data source into one dict for the LLM prompt."""
    t0 = time.time()
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        # NOTE(review): the metrics are included twice — nested under
        # "Calculations" AND merged flat below — so they appear twice in the
        # final prompt. Preserved as-is; confirm before removing either.
        "Calculations": formatted_metrics,
        **formatted_metrics,
    }
    logger.info("prepare_data took %.2f seconds", time.time() - t0)
    return collected_data
564
+
565
+
566
def main():
    """Interactive REPL: answer financial questions from Yahoo data, Google search and RAG.

    Reads queries from stdin until the user types exit/quit/bye. Every turn
    (user input, RAG response, bot answer) is appended to a running transcript
    that is prepended to the next prompt so the LLM sees conversation context.
    """
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")

    summarizer = DataSummarizer()
    conversation_history = []  # running transcript fed back into each prompt

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break

        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

            logger.info(
                f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")

            if entity and stock_ticker:
                # Company question: fan out all data fetches in parallel and
                # collect results keyed by the label each future was given.
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                # Each source degrades to a placeholder string when it failed.
                stock_data = results.get("stock_data", pd.DataFrame())
                formatted_stock_data = format_stock_data_for_gemini(
                    stock_data) if not stock_data.empty else "No historical data available."

                company_info = results.get("company_info", {})
                formatted_company_info = format_company_info_for_gemini(
                    company_info) if company_info else "No company info available."

                company_news = results.get("company_news", [])
                formatted_company_news = format_company_news_for_gemini(
                    company_news) if company_news else "No news available."

                current_stock_price = results.get("current_stock_price", None)

                formatted_metrics = calculate_metrics(stock_data, summarizer,
                                                      company_info) if not stock_data.empty else {
                    "Error": "No stock data for metrics"}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")

                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              google_results, formatted_metrics, google_snippet, rag_response)
                collected_data[
                    "Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                # Query the LLM with the English translation so the prompt
                # template and the question share a language.
                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {translation}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

            else:
                # No company/ticker detected: fall back to search + RAG only.
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")
                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

        except Exception as e:
            # Top-level guard: any per-turn failure is reported and the loop
            # continues rather than killing the chatbot.
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")


if __name__ == "__main__":
    main()
chatflask.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import requests
3
+ from bs4 import BeautifulSoup
4
+ import yfinance as yf
5
+ import pandas as pd
6
+ from datetime import datetime, timedelta
7
+ import logging
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from config import Config
11
+ import numpy as np
12
+ from typing import Optional, Tuple, List, Dict
13
+ from rag import get_answer
14
+ import time
15
+ from tenacity import retry, stop_after_attempt, wait_exponential
16
+
17
# Initialize Flask app
app = Flask(__name__)

# Set up logging: DEBUG and above is written both to app.log and to the
# console stream handler.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])

logger = logging.getLogger(__name__)

# Initialize the Gemini model (temperature 0.5 — moderately deterministic).
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API (key + engine id from Config).
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
34
+
35
+
36
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    """Call the Gemini model with *prompt*, retrying up to 5 times with
    exponential backoff (2s..8s) on any exception; the final failure is
    re-raised to the caller (``reraise=True``)."""
    return llm.invoke(prompt)
39
+
40
+
41
class DataSummarizer:
    """Bundles technical-indicator math plus Google search / snippet scraping.

    All indicator methods take the OHLCV DataFrame produced by
    ``fetch_stock_data_yahoo`` (lower-case columns, ``date`` as date objects),
    return a pandas object, and return None on any internal failure.

    Fix: ``calculate_yearly_summary`` previously wrote a ``'year'`` column
    into the caller's DataFrame (mutating a shared argument); it now groups
    on a derived Series and leaves the input untouched.
    """

    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[str]:
        """Run a Custom Search query and return an LLM summary of the hits."""
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize title + snippet of every hit with Gemini.
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        """Return "title\\nsnippet" for one search-result item (None on error)."""
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Simple moving average of 'close' over *window* rows."""
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Relative Strength Index over *window* rows (simple-mean variant)."""
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Exponential moving average of 'close' with span *window*."""
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        """Bollinger bands: rolling mean +/- 2 rolling standard deviations."""
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> \
            Optional[pd.DataFrame]:
        """MACD line (fast EMA - slow EMA) and its signal-line EMA."""
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Rolling std-dev of log returns, scaled by sqrt(window)."""
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Average True Range: rolling mean of the true range over *window* rows."""
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """On-Balance Volume: cumulative signed volume by close direction."""
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Per-year mean/max/min close and total volume.

        Groups on a Series derived from 'date' instead of writing a 'year'
        column into the caller's DataFrame (the original mutated its input).
        """
        start_time = time.time()
        try:
            years = pd.to_datetime(df['date']).dt.year.rename('year')
            yearly_summary = df.groupby(years).agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Rows whose 'date' falls inside the previous calendar year."""
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        """Percent change from the year's first open to the latest close.

        An empty YTD slice raises IndexError on iloc[0], which is caught and
        reported as None.
        """
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        """Price / earnings-per-share; None when eps is zero or on error."""
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Scrape the answer snippet from a Google results page.

        NOTE(review): screen-scraping Google is brittle — the CSS class names
        below change without notice and requests may be blocked.
        """
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            # Known snippet container classes, tried in priority order.
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
256
+
257
+
258
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Pull a stock ticker symbol out of an LLM reply.

    The model frequently wraps the symbol in markdown as ``... is **TSLA**.``;
    when that pattern is present the bolded token is extracted, otherwise the
    whole reply (stripped) is treated as the ticker. Returns None on failure.
    """
    t0 = time.time()
    try:
        has_bold_marker = "is **" in response and "**." in response
        ticker = (
            response.split("is **")[1].split("**.")[0].strip()
            if has_bold_marker
            else response.strip()
        )
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - t0)
        return ticker
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None
271
+
272
+
273
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run the LLM pipeline: language detection -> translation -> entity -> ticker.

    Returns ``(language, entity, translated_query, ticker)``; any stage that
    fails yields None for that element and every later one. On an unexpected
    exception all four elements are None.

    Fix: the original left debug ``print()`` calls in place, which wrote the
    intermediate translation/entity straight to the server's stdout; those
    now go through the logger only.
    """
    try:
        start_time = time.time()

        # Step 1: Detect Language
        response = invoke_llm(f"Detect the language for the following text: {query}")
        detected_language = response.content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary). NOTE(review): this
        # relies on the model answering exactly "English" for English text.
        translated_query = query
        if detected_language != "English":
            response = invoke_llm(f"Translate the following text to English: {query}")
            translated_query = response.content.strip()
            logger.info(f"Translation completed: {translated_query}")

        # Step 3: Detect the company entity mentioned in the query.
        response = invoke_llm(f"Detect the entity in the following text that is a company name: {translated_query}")
        detected_entity = response.content.strip()
        logger.info(f"Entity detected: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Map the company name to its stock ticker symbol.
        response = invoke_llm(f"What is the stock ticker symbol for the company {detected_entity}?")
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
317
+
318
+
319
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download roughly three years of daily OHLCV history for *symbol*.

    Returns a DataFrame with columns ``[date, open, high, low, close, volume]``
    (``date`` as ``datetime.date``), or an empty DataFrame when the download
    fails or yields no rows.
    """
    t0 = time.time()
    try:
        ticker = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        now = datetime.now()
        history = ticker.history(start=now - timedelta(days=3 * 365), end=now)
        if history.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        # Normalize Yahoo's capitalized column names and flatten the index.
        frame = history.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        ).reset_index()
        frame['date'] = frame['Date'].dt.date
        frame = frame.drop(columns=['Date'])[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in frame.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - t0)
        return frame
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
349
+
350
+
351
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Return the latest trading price for *symbol*, or None when unavailable.

    Fix: the original indexed ``info['currentPrice']`` directly, which raises
    KeyError (swallowed into a None) for instruments that do not publish that
    field (ETFs, indices, some foreign listings). Those typically expose
    ``regularMarketPrice`` instead, so that is used as a fallback.
    """
    start_time = time.time()
    try:
        info = yf.Ticker(symbol).info
        price = info.get('currentPrice')
        if price is None:
            # Fallback field used by Yahoo for non-equity instruments.
            price = info.get('regularMarketPrice')
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return price
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
361
+
362
+
363
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render the OHLCV DataFrame as a plain-text table for the LLM prompt.

    Fix: the original appended to a string with ``+=`` inside the row loop,
    which is quadratic over ~750 daily rows; the rows are now collected in a
    list and joined once. The emitted text is byte-identical.
    """
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        lines = [
            "Historical stock data for the last three years:\n",
            "Date | Open | High | Low | Close | Volume",
            "------------------------------------------------------",
        ]
        for _, row in stock_data.iterrows():
            lines.append(
                f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | "
                f"{row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}"
            )
        # Single join keeps this O(n); trailing newline matches the original.
        formatted_data = "\n".join(lines) + "\n"

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
381
+
382
+
383
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch a normalized company-profile dict for *symbol* from Yahoo Finance.

    Missing fields default to "N/A"; failures return {"error": message}.
    """
    started = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        raw = yf.Ticker(symbol).info
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - started)

        # Map our output keys to yfinance's field names.
        field_map = {
            "name": "longName",
            "sector": "sector",
            "industry": "industry",
            "marketCap": "marketCap",
            "summary": "longBusinessSummary",
            "website": "website",
            "address": "address1",
            "city": "city",
            "state": "state",
            "country": "country",
            "phone": "phone",
        }
        return {out_key: raw.get(src_key, "N/A") for out_key, src_key in field_map.items()}
    except Exception as exc:
        logger.error(f"Error fetching company info for {symbol}: {exc}")
        return {"error": str(exc)}
408
+
409
+
410
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Render the company-profile dict as labeled lines for the LLM prompt."""
    started = time.time()
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        lines = [
            "",
            "Company Information:",
            f"Name: {company_info['name']}",
            f"Sector: {company_info['sector']}",
            f"Industry: {company_info['industry']}",
            f"Market Cap: {company_info['marketCap']}",
            f"Summary: {company_info['summary']}",
            f"Website: {company_info['website']}",
            f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}",
            f"Phone: {company_info['phone']}",
        ]
        logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - started)
        return "\n".join(lines) + "\n"
    except Exception as exc:
        logger.error(f"Error formatting company info for Gemini: {exc}")
        return "Error formatting company info."
431
+
432
+
433
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Fetch the latest news items for *symbol*; returns [] on any failure."""
    started = time.time()
    try:
        articles = yf.Ticker(symbol).news
        if not articles:
            raise ValueError(f"No news found for symbol: {symbol}")
        logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - started)
        return articles
    except Exception as exc:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {exc}")
        return []
445
+
446
+
447
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render a list of Yahoo news items as labeled text for the LLM prompt.

    Fix: the original indexed article fields directly (article['title'], ...),
    so a single item missing any field raised KeyError and discarded the whole
    news section. Use .get() with "N/A" defaults so one malformed item cannot
    wipe out all the news.
    """
    start_time = time.time()
    try:
        if not news:
            return "No news available."

        formatted_news = "Latest company news:\n\n"
        for article in news:
            formatted_news += (f"Title: {article.get('title', 'N/A')}\n"
                               f"Publisher: {article.get('publisher', 'N/A')}\n"
                               f"Link: {article.get('link', 'N/A')}\n"
                               f"Published: {article.get('providerPublishTime', 'N/A')}\n\n")

        logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
465
+
466
+
467
def send_to_gemini_for_summarization(content: str) -> str:
    """Ask the LLM to summarize *content* (a string, or an iterable of strings).

    Fix: the original always did `" ".join(content)`; joining a plain string
    inserts a space between every character ("abc" -> "a b c"), corrupting the
    prompt. Join only when the input is not already a string.
    """
    start_time = time.time()
    try:
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
478
+
479
+
480
def answer_question_with_data(question: str, data: Dict) -> str:
    """Build a financial-advisor prompt from *data* and ask the LLM *question*."""
    started = time.time()
    try:
        data_str = "".join(f"{key}:\n{value}\n\n" for key, value in data.items())

        prompt = (f"You are a financial advisor. Begin your answer and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        logger.info("answer_question_with_data took %.2f seconds", time.time() - started)
        return response.content.strip()
    except Exception as exc:
        logger.error(f"Error answering question with data: {exc}")
        return "Error answering question."
498
+
499
+
500
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    """Compute technical indicators over *stock_data* and format them as strings.

    P/E Ratio is included only when the company reports a truthy trailingEps.
    Returns {"Error": ...} on any failure.

    Fix: the original duplicated the entire metrics dict in both branches of
    the EPS check; build it once and conditionally append the P/E entry.
    """
    start_time = time.time()
    try:
        formatted_metrics = {
            "Moving Average": summarizer.calculate_moving_average(stock_data).to_string(),
            "RSI": summarizer.calculate_rsi(stock_data).to_string(),
            "EMA": summarizer.calculate_ema(stock_data).to_string(),
            "Bollinger Bands": summarizer.calculate_bollinger_bands(stock_data).to_string(),
            "MACD": summarizer.calculate_macd(stock_data).to_string(),
            "Volatility": summarizer.calculate_volatility(stock_data).to_string(),
            "ATR": summarizer.calculate_atr(stock_data).to_string(),
            "OBV": summarizer.calculate_obv(stock_data).to_string(),
            "Yearly Summary": summarizer.calculate_yearly_summary(stock_data).to_string(),
            "YTD Performance": f"{summarizer.calculate_ytd_performance(stock_data):.2f}%",
        }

        # P/E needs earnings-per-share; skip when absent or zero (avoids
        # division by zero in the ratio).
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
550
+
551
+
552
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> \
        Dict[str, str]:
    """Assemble all gathered context into one flat dict for the answer prompt.

    NOTE(review): the metrics appear twice — nested under "Calculations" and
    flattened into the top level via update(). Presumably intentional prompt
    emphasis, but confirm before relying on either form.
    """
    start_time = time.time()
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    logger.info("prepare_data took %.2f seconds", time.time() - start_time)
    return collected_data
568
+
569
+
570
@app.route('/ask', methods=['POST'])
def ask():
    """Flask endpoint: answer a financial question posted as JSON {"question": ...}.

    Pipeline: detect language/entity/ticker; if a ticker is found, fan out data
    fetches concurrently, format everything, and prompt the LLM; otherwise fall
    back to RAG + Google search only. Returns {"answer": ...} or a 500 error.
    """
    try:
        user_input = request.json.get('question')
        logger.info(f"Received question: {user_input}")

        summarizer = DataSummarizer()

        # Detect language, entity, translation, and stock ticker
        language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

        logger.info(f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")

        if entity and stock_ticker:
            # Ticker path: fetch market data, company info/news, RAG and web
            # results in parallel; futures map back to result names.
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                    executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                    executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                    executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}

            stock_data = results.get("stock_data", pd.DataFrame())
            formatted_stock_data = format_stock_data_for_gemini(stock_data) if not stock_data.empty else "No historical data available."

            company_info = results.get("company_info", {})
            formatted_company_info = format_company_info_for_gemini(company_info) if company_info else "No company info available."

            company_news = results.get("company_news", [])
            formatted_company_news = format_company_news_for_gemini(company_news) if company_news else "No news available."

            current_stock_price = results.get("current_stock_price", None)

            formatted_metrics = calculate_metrics(stock_data, summarizer, company_info) if not stock_data.empty else {"Error": "No stock data for metrics"}

            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")

            collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                          google_results, formatted_metrics, google_snippet, rag_response)
            collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"

            # Use the (possibly translated-to-English) question for the prompt.
            answer = answer_question_with_data(f"{translation}", collected_data)

            return jsonify({"answer": answer})

        else:
            # No ticker detected: answer from RAG + web search only.
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}

            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")

            collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)

            answer = answer_question_with_data(f"{user_input}", collected_data)

            return jsonify({"answer": answer})

    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return jsonify({"error": "An error occurred while processing your request. Please try again later."}), 500
643
+
644
+
645
if __name__ == '__main__':
    # Bind on all interfaces so the containerized service is reachable.
    app.run(host='0.0.0.0', port=5000)
config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ # Load environment variables from .env file
5
+ load_dotenv()
6
+
7
class Config:
    """Application configuration: API keys/secrets read from the environment.

    Values come from os.getenv and are therefore None when the corresponding
    variable is missing; callers must handle that.
    """
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")            # Google Gemini LLM
    ALPHA_VANTAGE_KEY = os.getenv("ALPHA_VANTAGE_KEY")      # Alpha Vantage market data
    YAHOO_FINANCE_API_KEY = os.getenv("YAHOO_FINANCE_API_KEY")
    FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY")
    POLYGON_API_KEY = os.getenv("POLYGON_API_KEY")
    SECRET_KEY = os.getenv("SECRET_KEY")                    # web-framework session secret
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")            # Google Custom Search
    SEARCH_ENGINE_ID = os.getenv("SEARCH_ENGINE_ID")
    # Add any additional configuration variables here
+
embeddings.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import google.generativeai as genai
2
+ from dotenv import load_dotenv
3
+ import os
4
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
5
+ from langchain_cohere import CohereEmbeddings
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+
9
+ load_dotenv()
10
+
11
class Embeddings:
    """Factory wrapper that builds a LangChain embedding model by provider name.

    Supported (emb, model) combinations:
        google, models/embedding-001
        openai, openai
        cohere, cohere
        hf, all-MiniLM-L6-v2
        hf, BAAI/bge-large-en-v1.5
        hf, Alibaba-NLP/gte-large-en-v1.5, True

    Fix: an unknown provider previously made get_embedding() fall through and
    return None, which only surfaced later as an opaque AttributeError inside
    get_emb_len(); now it raises a clear ValueError immediately.
    """

    def __init__(self, emb, model, trust_remote=False, normalize=False):
        # Provider key ('google' | 'openai' | 'cohere' | 'hf') and model name.
        self.emb = emb
        self.model = model
        self.trust_remote = trust_remote      # pass trust_remote_code for HF models that need it
        self.normalize = normalize            # L2-normalize HF embeddings if True
        self.embedding = self.get_embedding()
        self.seq_len = self.get_emb_len()     # embedding dimensionality

    def get_emb_len(self):
        """Return the embedding vector dimensionality (probes with a dummy query)."""
        return len(self.embedding.embed_query('hi how are you'))

    def google_embedding(self):
        """Build a Google Generative AI embedding model (needs GOOGLE_API_KEY)."""
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        return GoogleGenerativeAIEmbeddings(model=self.model)

    def openai_embedding(self):
        """Build an OpenAI embedding model (needs OPENAI_API_KEY)."""
        return OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))

    def cohere_embedding(self):
        """Build a Cohere embedding model (needs COHERE_API_KEY)."""
        return CohereEmbeddings(cohere_api_key=os.getenv("COHERE_API_KEY"))

    def hf_embedding(self):
        """Build a local HuggingFace sentence-transformers embedding model."""
        model_args = {'trust_remote_code': True} if self.trust_remote else {}
        encode_args = {'normalize_embeddings': True} if self.normalize else {}
        return HuggingFaceEmbeddings(model_name=self.model, model_kwargs=model_args, encode_kwargs=encode_args)

    def get_embedding(self):
        """Dispatch to the provider-specific builder; raise on unknown provider."""
        if self.emb == 'google':
            return self.google_embedding()
        elif self.emb == 'openai':
            return self.openai_embedding()
        elif self.emb == 'cohere':
            return self.cohere_embedding()
        elif self.emb == 'hf':
            return self.hf_embedding()
        raise ValueError(f"Unknown embedding provider: {self.emb!r} "
                         "(expected 'google', 'openai', 'cohere', or 'hf')")
flasktest.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import json
4
+
5
+
6
def send_question_to_api(question):
    """POST *question* to the local /ask endpoint and return its answer string.

    Non-200 responses are returned as a formatted error string.
    """
    response = requests.post(
        'http://localhost:5000/ask',
        headers={'Content-Type': 'application/json'},
        data=json.dumps({'question': question}),
    )

    if response.status_code != 200:
        return f"Error: {response.status_code} - {response.text}"
    return response.json().get('answer')
17
+
18
+
19
def main():
    """Streamlit test UI: submit questions to the Flask chatbot and show history."""
    st.title("Financial Data Chatbot Tester")

    st.write("Enter your question below and get a response from the chatbot.")

    # Initialize session state to store question history across reruns.
    if 'history' not in st.session_state:
        st.session_state.history = []

    user_input = st.text_input("Your question:", "")

    if st.button("Submit"):
        if user_input:
            with st.spinner('Getting the answer...'):
                answer = send_question_to_api(user_input)
                st.session_state.history.append((user_input, answer))
                st.success(answer)
        else:
            st.warning("Please enter a question before submitting.")

    # Display the history of questions and answers (oldest first).
    if st.session_state.history:
        st.write("### History")
        for idx, (question, answer) in enumerate(st.session_state.history, 1):
            st.write(f"**Q{idx}:** {question}")
            st.write(f"**A{idx}:** {answer}")
            st.write("---")


if __name__ == '__main__':
    main()
index.html ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Chatbot</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            background-color: #f4f4f9;
            margin: 40px;
            text-align: center;
        }
        input[type="text"] {
            width: 300px;
            padding: 10px;
            font-size: 16px;
            margin-top: 20px;
            border: 2px solid #ccc;
            border-radius: 5px;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 20px;
            margin-top: 10px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background-color: #45a049;
        }
        p {
            margin-top: 20px;
            font-size: 18px;
            color: #333;
        }
    </style>
</head>
<body>
    <h1>Chatbot Interface</h1>
    <input type="text" id="question" placeholder="Ask a question...">
    <button onclick="askQuestion()">Ask</button>
    <p id="answer">Answer will appear here...</p>

    <script>
        // Send the typed question to the backend /chat/ endpoint (the FastAPI
        // route in rag.py) and render the JSON answer in the <p> below.
        async function askQuestion() {
            const questionInput = document.getElementById('question');
            const answerDisplay = document.getElementById('answer');
            const question = questionInput.value;

            const response = await fetch('/chat/', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ question: question })
            });
            if (response.ok) {
                const data = await response.json();
                answerDisplay.textContent = 'Answer: ' + data.answer;
            } else {
                answerDisplay.textContent = 'Error: Unable to fetch answer.';
            }
        }
    </script>
</body>
</html>
llm.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_google_genai import ChatGoogleGenerativeAI
2
+ import google.generativeai as genai
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain_groq import ChatGroq
5
+ import vertexai
6
+ from langchain_google_vertexai import ChatVertexAI
7
+
8
+ from dotenv import load_dotenv
9
+ import os
10
+
11
+ load_dotenv()
12
+
13
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
14
+
15
class LLM:
    """Factory wrapper that builds a LangChain chat model by provider name.

    Providers: 'gemini', 'vertex', 'openai', 'mixtral', 'llama'. Each branch
    falls back to a sensible default model when none is given.

    Fixes: the 'mixtral' branch previously overwrote any caller-supplied model
    unconditionally (inconsistent with every other branch), and an unknown
    provider left self.llm unset, deferring the failure to an opaque
    AttributeError in get_llm(); both now behave consistently.
    """

    def __init__(self, llm, model=None):
        if llm == 'gemini':
            if model is None:
                model = "gemini-pro"
            self.llm = ChatGoogleGenerativeAI(model=model, temperature=0.3)
        elif llm == 'vertex':
            vertexai.init(project="website-254017", location="us-central1")
            if model is None:
                model = "gemini-1.5-pro-preview-0514"
            self.llm = ChatVertexAI(model_name=model, temperature=0, max_tokens=8000)
        elif llm == 'openai':
            if model is None:
                model = 'gpt-3.5-turbo-0125'
            self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model=model)
        elif llm == 'mixtral':
            if model is None:
                model = "mixtral-8x7b-32768"
            # NOTE: env var is spelled GROK_API_KEY in the project's .env.
            self.llm = ChatGroq(temperature=0, groq_api_key=os.getenv("GROK_API_KEY"), model_name=model)
        elif llm == 'llama':
            if model is None:
                model = 'llama3-8b-8192'
            self.llm = ChatGroq(temperature=0, groq_api_key=os.getenv("GROK_API_KEY"), model_name=model)
        else:
            raise ValueError(f"Unknown LLM provider: {llm!r} "
                             "(expected 'gemini', 'vertex', 'openai', 'mixtral', or 'llama')")

    def get_llm(self):
        """Return the constructed LangChain chat model."""
        return self.llm
43
+
44
+
45
+
logging_config.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging.config
2
+
3
def setup_logging():
    """Configure root logging: DEBUG to both console and financial_adviser.log."""
    logging_config = {
        'version': 1,
        # Keep third-party/module loggers active alongside this config.
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            },
        },
        'handlers': {
            'console': {
                'level': 'DEBUG',
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
            },
            'file': {
                'level': 'DEBUG',
                'class': 'logging.FileHandler',
                'filename': 'financial_adviser.log',
                'formatter': 'standard',
            },
        },
        'loggers': {
            # Root logger: everything propagates here.
            '': {
                'handlers': ['console', 'file'],
                'level': 'DEBUG',
                'propagate': True,
            },
        },
    }

    logging.config.dictConfig(logging_config)

# Initialize the logger on import so importing modules can just use it.
setup_logging()
logger = logging.getLogger(__name__)
main.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
4
+ from data_extraction import Extraction
5
+ import nest_asyncio
6
+ from chunking import Chunker
7
+ from embeddings import Embeddings
8
+ from vectorstore import VectorDB
9
+ from retriever import Retriever, CreateBM25Retriever
10
+ from llm import LLM
11
+ from langchain_core.prompts import PromptTemplate
12
+ from chain import Chain
13
+ from streamlit_chat import message
14
+
15
# --- Streamlit session state and shared pipeline objects (run on import) ---

# Chat transcript shown in the UI; seeded with a greeting from the bot.
if 'responses' not in st.session_state:
    st.session_state['responses'] = ["How can I assist you?"]

if 'requests' not in st.session_state:
    st.session_state['requests'] = []

# Rolling 3-turn conversation memory for the LLM chain.
if 'buffer_memory' not in st.session_state:
    st.session_state.buffer_memory = ConversationBufferWindowMemory(k=3, return_messages=True)

# Allow nested event loops (Streamlit + async extraction libraries).
nest_asyncio.apply()
ext = Extraction('fast')
chnk = Chunker(chunk_size=1000, chunk_overlap=200)
emb = Embeddings("hf", "all-MiniLM-L6-v2")
_llm = LLM('vertex').get_llm()
ch = Chain(_llm, st.session_state.buffer_memory)
conversation = ch.get_chain_with_history()
31
+
32
def query_refiner(conversation, query):
    """Rewrite *query* into a standalone question using the prior *conversation*.

    Fix: the original assignment ended with a trailing comma, which made
    `prompt` a one-element tuple instead of a string, so a tuple was passed to
    the LLM's invoke(). The comma is removed so a plain string is sent.
    """
    prompt = f"Given the following user query and historical user queries, rephrase the users current query to form a meaningful and clear question.Previously user has asked the following: \n{conversation}\n\n User's Current Query: {query}. What will be the refined query? Only provide the query without any extra details or explanations."
    ans = _llm.invoke(prompt).content
    return ans
36
+
37
def get_conversation_string():
    """Return the user's prior questions as newline-separated 'Human:' lines."""
    turns = []
    # One prior request exists for each response beyond the seeded greeting.
    for idx in range(len(st.session_state['responses']) - 1):
        turns.append("Human: " + st.session_state['requests'][idx] + "\n")
    return "".join(turns)
43
+
44
def main():
    """Streamlit chat-with-PDF app: upload PDFs, index them, then chat via RAG."""
    inp_dir = "./inputs"
    db = 'pinecone'
    db_dir = 'pineconedb'
    st.set_page_config("Chat PDF")
    st.header("Chat with PDF")

    response_container = st.container()
    textcontainer = st.container()
    ret = None
    with textcontainer:
        query = st.text_input("Query: ", key="input")
        if query:
            # Lazily build the retriever on first query of this rerun.
            if ret is None:
                ret = Retriever(db, db_dir, emb.embedding, 'ensemble', 5)
            with st.spinner("typing..."):
                conversation_string = get_conversation_string()
                # Rewrite follow-up questions into standalone queries using
                # the prior conversation.
                if len(st.session_state['responses']) != 0:
                    refined_query = query_refiner(conversation_string, query)
                else:
                    refined_query = query
                st.subheader("Refined Query:")
                st.write(refined_query)
                # Retrieve context with the refined query, but answer the
                # user's original wording.
                context, context_list = ret.get_context(refined_query)
                response = conversation.predict(input=f"Context:\n {context} \n\n Query:\n{query}")
                # response += '\n' + "Source: " + src
                st.session_state.requests.append(query)
                st.session_state.responses.append(response)

    with response_container:
        if st.session_state['responses']:
            for i in range(len(st.session_state['responses'])):
                message(st.session_state['responses'][i], key=str(i))
                if i < len(st.session_state['requests']):
                    message(st.session_state["requests"][i], is_user=True, key=str(i) + '_user')

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader("Upload your PDF Files and Click on the Submit & Process Button", accept_multiple_files=True)
        pdfs = []
        if pdf_docs:
            # Persist uploads to disk so the extraction pipeline can read them.
            for pdf_file in pdf_docs:
                filename = pdf_file.name
                path = os.path.join(inp_dir, filename)
                with open(path, "wb") as f:
                    f.write(pdf_file.getvalue())
                pdfs.append(path)

            # Extract -> chunk -> embed into the vector store, plus a BM25
            # index pickled for the ensemble retriever.
            with st.spinner("Processing..."):
                texts, metas = ext.get_text(pdfs)
                docs = chnk.get_chunks(texts, metas)
                vs = VectorDB(db, emb.embedding, db_dir, docs=docs)
                bm = CreateBM25Retriever(docs)
                st.success("Done")

if __name__ == "__main__":
    main()
+ main()
rag.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from embeddings import Embeddings
4
+ from chain import Chain
5
+ from llm import LLM
6
+ from retriever import Retriever
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.responses import HTMLResponse
9
+ from functools import lru_cache
10
+ from tools import *
11
+ import re
12
+
13
# --- Module-level RAG pipeline objects (built once on import) ---
emb = Embeddings("hf", "all-MiniLM-L6-v2")
llm = LLM('gemini').get_llm()
ch = Chain(llm, None)
ret = Retriever('pinecone', 'pinecone', emb.embedding, 'ensemble', 5)

# Module-level flag: whether the current question was asked in Arabic.
# Set by get_answer() before calling investment_banker().
is_arabic = False
19
+
20
+ @lru_cache()
21
+ def investment_banker(query):
22
+ global is_arabic
23
+ context, context_list = ret.get_context(query)
24
+ if not is_arabic:
25
+ prompt_template = f"""
26
+ You are an investment banker and financial advisor.
27
+ Answer the question as detailed as possible from the provided context and make sure to provide all the details.
28
+ Answer only from the context. If the answer is not in provided context, say "Answer not in context".\n\n
29
+ Context:\n {context}\n\n
30
+ Question: \n{query}\n
31
+
32
+ Answer:
33
+ """
34
+ else:
35
+ prompt_template = f"""
36
+ You are an investment banker and financial advisor.
37
+ Answer the question as detailed as possible from the provided context and make sure to provide all the details.
38
+ Answer only from the context. If the answer is not in provided context, say "Answer not in context".
39
+ Return the answer in Arabic only.\n\n
40
+ Context:\n {context}\n\n
41
+ Question: \n{query}\n
42
+
43
+ Answer:
44
+ """
45
+ response = ch.run_conversational_chain(prompt_template)
46
+ is_arabic = False
47
+ return response
48
+
49
def check_arabic(s):
    """Return True when *s* contains at least one Arabic-block character."""
    return bool(re.search(r'[\u0600-\u06FF]', s))
55
+
56
+ history = ""
57
+
58
@lru_cache()
def refine_query(query, conversation):
    """Ask the LLM to turn *query* into a standalone English question.

    Cached on the (query, conversation) pair, so repeated identical turns
    skip the LLM round-trip.
    """
    prompt = f"""Given the following user query and historical user conversation with banker.
    If the current user query is in arabic, convert it to english and then proceed.
    If conversation history is empty return the current query as it is.
    If the query is a continuation of previous conversation then only rephrase the users current query to form a meaningful and clear question.
    Otherwise return the user query as it is.
    Previously user and banker had the following conversation: \n{conversation}\n\n User's Current Query: {query}.
    What will be the refined query? Only provide the query without any extra details or explanations."""
    return llm.invoke(prompt).content
69
+
70
+
71
def get_answer(query):
    """Full RAG turn: detect language, refine the query, answer, log history."""
    global history
    global is_arabic

    is_arabic = check_arabic(query)
    refined = refine_query(query, history)
    answer = investment_banker(refined)

    # Append this exchange to the running transcript used for refinement.
    history += "Human: " + refined + "\n"
    history += "Banker: " + answer + "\n"
    return answer
82
if __name__ == "__main__":
    # Fix: get_answer() was called without its required `query` argument,
    # so running this module directly always raised TypeError. Take the
    # question from the command line, with a sample default.
    import sys
    question = " ".join(sys.argv[1:]) or "What services does an investment banker provide?"
    response = get_answer(question)
    print(response)
85
+ # app = FastAPI()
86
+
87
+ # class Query(BaseModel):
88
+ # question: str
89
+
90
+ # @app.post("/chat/")
91
+ # async def chat(query: Query):
92
+ # global history
93
+ # global is_arabic
94
+
95
+ # try:
96
+
97
+ # is_arabic = check_arabic(query.question)
98
+ # ref_query = refine_query(query.question, history)
99
+
100
+
101
+ # print(query.question, ref_query)
102
+ # print(is_arabic)
103
+ # ans = investment_banker(ref_query)
104
+ # history += "Human: "+ ref_query + "\n"
105
+ # history += "Banker: "+ ans + "\n"
106
+ # return {"question": query.question, "answer": ans}
107
+ # except Exception as e:
108
+ # raise HTTPException(status_code=500, detail=str(e))
109
+
110
+
111
+ # @app.get("/", response_class=HTMLResponse)
112
+ # async def read_index():
113
+ # with open('index.html', 'r') as f:
114
+ # return HTMLResponse(content=f.read())
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pydantic
2
+ langchain
3
+ yfinance
4
+ langchain_google_genai
5
+ langchain_openai
6
+ langchain_cohere
7
+ google-generativeai
8
+ langchain_groq
9
+ python-dotenv
10
+ vertexai
11
+ langchain_pinecone
12
+ qdrant_client
13
+ uvicorn
14
+ langchain-community
15
+ langchain_google_vertexai
16
+ sentence-transformers
17
+ rank_bm25
18
+ matplotlib
19
+ pandas
20
+ numpy
21
+ requests
22
+ spacy
23
+ transformers
24
+ torch
25
+ sentencepiece
26
+ streamlit
27
+ flask
28
+ bs4
29
+ tenacity
30
+ loguru
retriever.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
2
+ from langchain.vectorstores import FAISS, Chroma, Qdrant
3
+ from qdrant_client import QdrantClient
4
+ from langchain_pinecone import PineconeVectorStore
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import pickle
8
+
9
+ load_dotenv()
10
+
11
class CreateBM25Retriever:
    """Build a BM25 retriever over *docs* and pickle it to bm25retriever.pkl.

    The pickle is later loaded by Retriever when strategy == 'ensemble'.
    """
    def __init__(self, docs):
        self.bm25_retriever = BM25Retriever.from_documents(docs)
        # Persist so serving processes can reuse the index without the docs.
        with open('bm25retriever.pkl', 'wb') as outp:
            pickle.dump(self.bm25_retriever, outp, pickle.HIGHEST_PROTOCOL)
16
+
17
class Retriever:
    """Load a vector store (FAISS/Chroma/Qdrant/Pinecone) and expose retrieval.

    With strategy 'ensemble', combines the vector retriever with a pickled
    BM25 retriever (created by CreateBM25Retriever) weighted 0.4/0.6.
    """
    def __init__(self, db, per_dir, embeddings, strategy, k, collection_name="mydocuments"):
        self.db = db
        self.strategy = strategy
        self.per_dir = per_dir  # persistence directory / local path for the store
        if self.db == 'faiss':
            self.db_ = FAISS.load_local(self.per_dir, embeddings, allow_dangerous_deserialization=True)
        elif self.db == 'chroma':
            self.db_ = Chroma(persist_directory=self.per_dir, embedding_function=embeddings)
        elif self.db == 'qdrant':
            self.db_ = Qdrant(client=QdrantClient(path=self.per_dir), collection_name=collection_name, embeddings=embeddings)
        elif self.db == 'pinecone':
            self.db_ = PineconeVectorStore(pinecone_api_key=os.getenv("PINECONE_API_KEY"), index_name=collection_name, embedding=embeddings)
        self.retriever = self.db_.as_retriever(search_kwargs={"k": k})

        if strategy == 'ensemble':
            # Load the BM25 index pickled at build time and fuse both
            # retrievers (sparse 0.4 / dense 0.6).
            with open('bm25retriever.pkl', 'rb') as inp:
                self.bm25_retriever = pickle.load(inp)
            self.bm25_retriever.k = k
            self.retriever = EnsembleRetriever(retrievers=[self.bm25_retriever, self.retriever],
                                               weights=[0.4, 0.6])

    def get_docs(self, query):
        """Return the top-k relevant documents for *query*."""
        return self.retriever.get_relevant_documents(query)

    def get_context(self, query):
        """Return (concatenated context with sources, list of raw chunk texts)."""
        docs = self.get_docs(query)
        context = ""
        context_list = []
        # src = []
        for txt in docs:
            context += '\n\n' + txt.page_content + "\n" + "Source: " + txt.metadata['source']
            context_list.append(txt.page_content)
            # src.append(txt.metadata['source'])
        # src = max(set(src), key=src.count)
        return context, context_list
53
+
tools.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from datetime import datetime, timedelta
3
+ import yfinance as yf
4
+ from langchain.prompts import MessagesPlaceholder, ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, AIMessagePromptTemplate
5
+ from pydantic import BaseModel, Field
6
+ from langchain.tools import BaseTool
7
+ from typing import Optional, Type
8
+ from typing import List
9
+ from functools import lru_cache
10
+
11
+
12
@lru_cache()
def get_stock_price(symbol):
    """Return (price, currency) for *symbol*: today's close rounded to 2 dp.

    NOTE(review): lru_cache memoizes per process, so a long-running app will
    keep serving the first price fetched — confirm that staleness is intended.

    Raises:
        ValueError: If yfinance returns no rows for the symbol.
    """
    ticker = yf.Ticker(symbol)
    todays_data = ticker.history(period='1d')
    if todays_data.empty:
        # Previously an unknown/bad symbol surfaced as an opaque IndexError.
        raise ValueError(f"No price data returned for symbol {symbol!r}")
    # .iloc[0]: todays_data['Close'][0] used the deprecated positional
    # fallback on a date-indexed Series; positional access must be explicit.
    price = round(todays_data['Close'].iloc[0], 2)
    currency = ticker.info['currency']
    return price, currency
19
+
20
@lru_cache()
def get_stock_data_yahoo(ticker):
    """Return one year of daily price history for *ticker* via yfinance.

    Results are memoized per process by lru_cache.
    """
    return yf.Ticker(ticker).history(period="1y")
25
+
26
@lru_cache()
def get_company_profile_yahoo(ticker):
    """Return a compact company profile dict for *ticker*.

    Keys: name, sector, industry, marketCap, website, description — each
    pulled with ``.get`` from yfinance's ``info`` mapping, so missing
    fields come back as None.
    """
    info = yf.Ticker(ticker).info
    field_map = {
        "name": "shortName",
        "sector": "sector",
        "industry": "industry",
        "marketCap": "marketCap",
        "website": "website",
        "description": "longBusinessSummary",
    }
    return {field: info.get(source_key) for field, source_key in field_map.items()}
39
+
40
@lru_cache()
def get_company_news_yahoo(ticker):
    """Return whatever news items yfinance currently exposes for *ticker*."""
    return yf.Ticker(ticker).news
45
+
46
@lru_cache()
def get_price_change_percent(symbol, days_ago):
    """Return the percentage change in *symbol*'s close over *days_ago* days.

    Computed as (last_close - first_close) / first_close * 100 over the
    window [today - days_ago, today], rounded to 2 decimal places.

    Raises:
        ValueError: If yfinance returns no rows for the window (previously
            this surfaced as an opaque IndexError on .iloc[0]).
    """
    ticker = yf.Ticker(symbol)
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_ago)

    # Convert dates to string format that yfinance can accept
    start_date = start_date.strftime('%Y-%m-%d')
    end_date = end_date.strftime('%Y-%m-%d')

    historical_data = ticker.history(start=start_date, end=end_date)
    if historical_data.empty:
        raise ValueError(
            f"No historical data for {symbol!r} between {start_date} and {end_date}"
        )

    old_price = historical_data['Close'].iloc[0]
    new_price = historical_data['Close'].iloc[-1]

    percent_change = ((new_price - old_price) / old_price) * 100
    return round(percent_change, 2)
63
+
64
@lru_cache()
def calculate_performance(symbol, days_ago):
    """Return the % change in *symbol*'s close over *days_ago* days.

    Kept for backward compatibility: its body was a line-for-line duplicate
    of ``get_price_change_percent``, so it now delegates instead of
    re-implementing the same computation.
    """
    return get_price_change_percent(symbol, days_ago)
76
+
77
def get_best_performing(stocks, days_ago):
    """Return (best_ticker, best_percent_change) among *stocks* over *days_ago* days.

    Bug fix: the previous ``@lru_cache()`` decorator made every call fail —
    ``StockGetBestPerformingTool`` passes *stocks* as a list, which is
    unhashable, so the cache wrapper raised TypeError before the body ran.
    Per-symbol results are still cached via ``calculate_performance``'s own
    lru_cache.

    Symbols whose performance cannot be computed are skipped (best-effort);
    returns (None, None) if no symbol succeeds.
    """
    best_stock = None
    best_performance = None
    for stock in stocks:
        try:
            performance = calculate_performance(stock, days_ago)
            if best_performance is None or performance > best_performance:
                best_stock = stock
                best_performance = performance
        except Exception as e:
            # Deliberate best-effort: a bad ticker shouldn't abort the scan.
            print(f"Could not calculate performance for {stock}: {e}")
    return best_stock, best_performance
90
+
91
class StockPriceCheckInput(BaseModel):
    """Input for Stock price check."""

    # Ticker as understood by the yfinance API (e.g. "AAPL", "^GSPC").
    stockticker: str = Field(..., description="Ticker symbol for stock or index")
95
+
96
class StockPriceTool(BaseTool):
    """Agent tool: fetch today's closing price for one ticker via yfinance."""

    name = "get_stock_ticker_price"
    description = "Useful for when you need to find out the price of the stock today. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        # Delegates to the lru_cached module-level helper; returns a
        # human-readable "<currency> <price>" string for the agent.
        # print("i'm running")
        price_response, currency = get_stock_price(stockticker)

        return f"{currency} {price_response}"

    def _arun(self, stockticker: str):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput
110
+
111
class PrevYearStockTool(BaseTool):
    """Agent tool: return the past year's daily price history for a ticker."""

    name = "get_past_year_stock_data"
    description = "Useful for when you need to find out the past 1 year performance of a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        # Returns the raw yfinance history frame for the agent to summarize.
        price_response = get_stock_data_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput
123
+
124
class StockNewsTool(BaseTool):
    """Agent tool: return recent news items for a ticker via yfinance."""

    name = "get_news_about_stock"
    description = "Useful for when you need recent news related to a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        # Raw news payload from yfinance, passed through unmodified.
        price_response = get_company_news_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput
136
+
137
class StockProfileTool(BaseTool):
    """Agent tool: return a compact company profile for a ticker."""

    name = "get_profile_of_stock"
    description = "Useful for when you need details or profile of a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        # Profile dict (name/sector/industry/marketCap/website/description).
        price_response = get_company_profile_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput
149
+
150
class StockChangePercentageCheckInput(BaseModel):
    """Input for Stock ticker check. for percentage check"""

    # Ticker as understood by the yfinance API (e.g. "AAPL", "^GSPC").
    stockticker: str = Field(..., description="Ticker symbol for stock or index")
    # Lookback window in days for the percentage-change computation.
    days_ago: int = Field(..., description="Int number of days to look back")
155
+
156
class StockPercentageChangeTool(BaseTool):
    """Agent tool: percentage change of a ticker's close over N days."""

    name = "get_price_change_percent"
    description = "Useful for when you need to find out the performance or percentage change in a stock's value. You should input the stock ticker used on the yfinance API and also input the number of days to check the change over"

    def _run(self, stockticker: str, days_ago: int):
        # Rounded percentage (float) from the lru_cached helper.
        price_change_response = get_price_change_percent(stockticker, days_ago)

        return price_change_response

    def _arun(self, stockticker: str, days_ago: int):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockChangePercentageCheckInput
169
+
170
class StockBestPerformingInput(BaseModel):
    """Input for Stock ticker check. for percentage check"""

    # Multiple tickers to compare, each as understood by the yfinance API.
    stocktickers: List[str] = Field(..., description="Ticker symbols for stocks or indices")
    # Lookback window in days for the comparison.
    days_ago: int = Field(..., description="Int number of days to look back")
175
+
176
class StockGetBestPerformingTool(BaseTool):
    """Agent tool: find the best-performing ticker in a list over N days."""

    name = "get_best_performing"
    # NOTE(review): description has a typo ("need to the performance") —
    # it is a runtime prompt string, so it is left byte-identical here.
    description = "Useful for when you need to the performance of multiple stocks over a period. You should input a list of stock tickers used on the yfinance API and also input the number of days to check the change over"

    def _run(self, stocktickers: List[str], days_ago: int):
        # NOTE(review): get_best_performing is decorated with @lru_cache(),
        # and a list argument is unhashable — this call raises TypeError as
        # written; confirm and fix at the helper.
        price_change_response = get_best_performing(stocktickers, days_ago)

        return price_change_response

    def _arun(self, stockticker: List[str], days_ago: int):
        # Async variant intentionally unsupported.
        raise NotImplementedError("This tool does not support async")

    # Pydantic schema the agent framework uses to validate tool arguments.
    args_schema: Optional[Type[BaseModel]] = StockBestPerformingInput