import requests
from bs4 import BeautifulSoup
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config
import numpy as np
from typing import Optional, Tuple, List, Dict
from rag import get_answer

# Set up logging
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])
logger = logging.getLogger(__name__)

# Initialize the Gemini model
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
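
# The config module is not included in this file. A minimal illustrative config.py
# (attribute names inferred from the accesses above; environment-variable sourcing is
# an assumption, not part of the original) could look like:
#
#     import os
#
#     class Config:
#         GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
#         GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
#         SEARCH_ENGINE_ID = os.environ["SEARCH_ENGINE_ID"]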

def fetch_google_snippet(query: str) -> Optional[str]:
    """Scrape the first answer snippet from a plain Google results page."""
    try:
        search_url = f"https://www.google.com/search?q={query}"
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
        }
        response = requests.get(search_url, headers=headers, timeout=10)
        soup = BeautifulSoup(response.text, 'html.parser')
        snippet_classes = [
            'BNeawe iBp4i AP7Wnd',
            'BNeawe s3v9rd AP7Wnd',
            'BVG0Nb',
            'kno-rdesc'
        ]
        for cls in snippet_classes:
            snippet = soup.find('div', class_=cls)
            if snippet:
                return snippet.get_text()
        return "Snippet not found."
    except Exception as e:
        logger.error(f"Error fetching Google snippet: {e}")
        return None
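
# Illustrative usage (assumes outbound network access; Google's result markup changes
# frequently, so the hard-coded CSS classes above may need to be refreshed):
#   snippet = fetch_google_snippet("NVIDIA market capitalization")
#   print(snippet or "lookup failed")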

class DataSummarizer:
    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[Dict]:
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params, timeout=10)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            return df['close'].rolling(window=window).mean()
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            return 100 - (100 / (1 + rs))
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            return df['close'].ewm(span=window, adjust=False).mean()
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            return pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26,
                       signal_window: int = 9) -> Optional[pd.DataFrame]:
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            return pd.DataFrame({'MACD': macd, 'Signal Line': signal})
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            return log_returns.rolling(window=window).std() * np.sqrt(window)
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            return true_range.rolling(window=window).mean()
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        try:
            return (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        try:
            df = df.copy()  # avoid mutating the caller's DataFrame
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            return df.loc[mask]
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            return ((closing_price - opening_price) / opening_price) * 100
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            return current_price / eps
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        try:
            return fetch_google_snippet(query)
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
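
# A minimal, illustrative smoke test for the indicator helpers above. It is not called
# anywhere in the app; the synthetic OHLCV frame and the 5-day windows are assumptions
# chosen purely for demonstration.
def _demo_indicators() -> None:
    summarizer = DataSummarizer()
    demo = pd.DataFrame({
        'date': pd.date_range('2024-01-01', periods=30, freq='D').date,
        'open': np.linspace(100, 110, 30),
        'high': np.linspace(101, 111, 30),
        'low': np.linspace(99, 109, 30),
        'close': np.linspace(100, 110, 30) + np.sin(np.arange(30)),
        'volume': np.full(30, 1_000_000),
    })
    # Each helper returns a Series/DataFrame aligned to the input rows.
    print(summarizer.calculate_rsi(demo, window=5).tail())
    print(summarizer.calculate_bollinger_bands(demo, window=5).tail())
    print(summarizer.calculate_macd(demo).tail())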

def extract_ticker_from_response(response: str) -> Optional[str]:
    try:
        if "is **" in response and "**." in response:
            return response.split("is **")[1].split("**.")[0].strip()
        return response.strip()
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None

def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    try:
        prompt = f"Detect the language for the following text: {query}"
        response = llm.invoke(prompt)
        detected_language = response.content.strip()

        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = llm.invoke(prompt)
            translated_query = response.content.strip()

        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = llm.invoke(prompt)
        detected_entity = response.content.strip()

        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = llm.invoke(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
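
# Illustrative call showing the shape of the returned tuple only; the exact strings
# depend on the model's replies and are not guaranteed:
#   detect_translate_entity_and_ticker("How did Apple stock do this year?")
#   -> ("English", "Apple", "How did Apple stock do this year?", "AAPL")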

def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")
        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)
        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")
        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )
        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]
        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")
        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
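
# Expected shape of the returned frame (the values below are made up for illustration):
#
#         date    open    high     low   close    volume
#   2022-06-01  148.03  149.87  146.84  149.10  74286600
#   2022-06-02  147.83  151.27  146.86  151.21  72348100
#   ...
#
# Downstream helpers (the indicator calculations, YTD performance, and Gemini formatting)
# all assume these lowercase column names and `datetime.date` values in the 'date' column.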

def fetch_current_stock_price(symbol: str) -> Optional[float]:
    try:
        stock = yf.Ticker(symbol)
        return stock.info['currentPrice']
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None

def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    try:
        if stock_data.empty:
            return "No historical data available."
        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"
        for index, row in stock_data.iterrows():
            formatted_data += (f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | "
                               f"{row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n")
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."

def fetch_company_info_yahoo(symbol: str) -> Dict:
    try:
        if not symbol:
            return {"error": "Invalid symbol"}
        stock = yf.Ticker(symbol)
        company_info = stock.info
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}

def format_company_info_for_gemini(company_info: Dict) -> str:
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"
        formatted_info = (f"\nCompany Information:\n"
                          f"Name: {company_info['name']}\n"
                          f"Sector: {company_info['sector']}\n"
                          f"Industry: {company_info['industry']}\n"
                          f"Market Cap: {company_info['marketCap']}\n"
                          f"Summary: {company_info['summary']}\n"
                          f"Website: {company_info['website']}\n"
                          f"Address: {company_info['address']}, {company_info['city']}, "
                          f"{company_info['state']}, {company_info['country']}\n"
                          f"Phone: {company_info['phone']}\n")
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."

def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        if not news:
            raise ValueError(f"No news found for symbol: {symbol}")
        return news
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []

def format_company_news_for_gemini(news: List[Dict]) -> str:
    try:
        if not news:
            return "No news available."
        formatted_news = "Latest company news:\n\n"
        for article in news:
            # Use .get() so a single article missing a field doesn't discard all the news.
            formatted_news += (f"Title: {article.get('title', 'N/A')}\n"
                               f"Publisher: {article.get('publisher', 'N/A')}\n"
                               f"Link: {article.get('link', 'N/A')}\n"
                               f"Published: {article.get('providerPublishTime', 'N/A')}\n\n")
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."

def send_to_gemini_for_summarization(content: str) -> str:
    try:
        # `content` is already a single string; joining it with " " would interleave
        # spaces between every character, so pass it through as-is.
        prompt = f"Summarize the main points of this article.\n\n{content}"
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."

def answer_question_with_data(question: str, data: Dict) -> str:
    try:
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"
        prompt = (f"You are a financial advisor. Begin your answer by stating that, and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Keep your answer professional and well structured.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. "
                  f"If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."

def format_google_results(google_results: Optional[Dict], summarizer: DataSummarizer, query: str) -> str:
    try:
        if google_results:
            google_content = [summarizer.extract_content_from_item(item) for item in google_results.get('items', [])]
            formatted_google_content = "\n\n".join(google_content)
        else:
            formatted_google_content = "No additional news found through Google Search."
        snippet_query1 = f"{query} I want the answer only"
        snippet_query2 = f"{query}"
        google_snippet1 = summarizer.fetch_google_snippet(snippet_query1)
        google_snippet2 = summarizer.fetch_google_snippet(snippet_query2)
        google_snippet = google_snippet1 if google_snippet1 and google_snippet1 != "Snippet not found." else google_snippet2
        formatted_google_content += f"\n\nGoogle Snippet: {google_snippet}"
        return formatted_google_content
    except Exception as e:
        logger.error(f"Error formatting Google results: {e}")
        return "Error formatting Google results."

def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)

        formatted_metrics = {
            "Moving Average": moving_average.to_string(),
            "RSI": rsi.to_string(),
            "EMA": ema.to_string(),
            "Bollinger Bands": bollinger_bands.to_string(),
            "MACD": macd.to_string(),
            "Volatility": volatility.to_string(),
            "ATR": atr.to_string(),
            "OBV": obv.to_string(),
            "Yearly Summary": yearly_summary.to_string(),
            "YTD Performance": f"{ytd_performance:.2f}%"
        }

        # Only include the P/E ratio when trailing EPS is available and the calculation succeeds.
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            if pe_ratio is not None:
                formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}

def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 summarized_google_content: str, formatted_metrics: Dict[str, str]) -> Dict[str, str]:
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": summarized_google_content,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    return collected_data

def translate_response(response: str, target_language: str) -> str:
    try:
        prompt = f"Translate the following text to {target_language}: {response}"
        translation = llm.invoke(prompt)
        return translation.content.strip()
    except Exception as e:
        logger.error(f"Error translating response: {e}")
        return response  # Return the original response if translation fails

def main():
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")
    summarizer = DataSummarizer()
    conversation_history = []

    while True:
        user_input = input("You: ")
        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break
        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)
            if language and entity and translation and stock_ticker:
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(summarizer.google_search, f"{user_input} latest financial news"): "google_results"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                stock_data = results["stock_data"]
                formatted_stock_data = format_stock_data_for_gemini(stock_data)
                company_info = results["company_info"]
                formatted_company_info = format_company_info_for_gemini(company_info)
                company_news = results["company_news"]
                formatted_company_news = format_company_news_for_gemini(company_news)
                current_stock_price = results["current_stock_price"]
                google_results = results["google_results"]

                formatted_google_content = format_google_results(google_results, summarizer, user_input)
                summarized_google_content = send_to_gemini_for_summarization(formatted_google_content)
                formatted_metrics = calculate_metrics(stock_data, summarizer, company_info)
                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              summarized_google_content, formatted_metrics)
                collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price else "N/A"

                rag_response = get_answer(user_input)
                collected_data["RAG Response"] = rag_response
                conversation_history.append(f"RAG Response: {rag_response}")

                history_context = "\n".join(conversation_history)
                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)
                if language != "English":
                    answer = translate_response(answer, language)
                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")
            else:
                response = "I'm sorry, I couldn't process your request. Could you please rephrase?"
                print(f"Bot: {response}")
                conversation_history.append(f"Bot: {response}")
        except Exception as e:
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")


if __name__ == "__main__":
    main()