Spaces:
Runtime error
Runtime error
from flask import Flask, request, jsonify | |
import requests | |
from bs4 import BeautifulSoup | |
import yfinance as yf | |
import pandas as pd | |
from datetime import datetime, timedelta | |
import logging | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from config import Config | |
import numpy as np | |
from typing import Optional, Tuple, List, Dict | |
from rag import get_answer | |
import time | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
import threading | |
import streamlit as st | |
import json | |
# Initialize Flask app
app = Flask(__name__)

# Set up logging: DEBUG and above goes to both app.log and the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])
# urllib3 logs each request line (path plus query string) at DEBUG. The Google
# Custom Search call sends GOOGLE_API_KEY as a query parameter, so with the root
# logger at DEBUG the key would be written to app.log in plain text.
logging.getLogger("urllib3").setLevel(logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Gemini model shared by every LLM call in this module.
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
def invoke_llm(prompt):
    """Send *prompt* to the module-level Gemini chat model and return its response message.

    Callers read the generated text from the returned object's ``.content``.
    """
    result = llm.invoke(prompt)
    return result
class DataSummarizer:
    """Web-search helpers plus technical-indicator calculations over daily price data.

    The ``calculate_*`` methods read lower-case ``date``/``open``/``high``/``low``/
    ``close``/``volume`` columns, the layout ``fetch_stock_data_yahoo`` produces.
    Every method catches its own exceptions, logs them and returns ``None``.
    """

    # Seconds to wait on outbound HTTP calls. Without a timeout, requests waits
    # indefinitely, and a stalled server would hang the calling worker thread.
    REQUEST_TIMEOUT = 10

    def google_search(self, query: str) -> Optional[str]:
        """Search via the Google Custom Search JSON API and return an LLM summary of the hits.

        Returns ``None`` if the request, response parsing or the LLM call fails.
        """
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            search_results = response.json()
            items = search_results.get('items', [])
            content = "\n\n".join(f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items)
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except requests.RequestException as e:
            # The messages of requests' exceptions embed the full request URL, whose
            # query string carries GOOGLE_API_KEY; log only the error class and status.
            status = getattr(getattr(e, "response", None), "status_code", None)
            logger.error(f"Error during Google Search API request: {type(e).__name__} (HTTP status: {status})")
            return None
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        """Return ``"<title>\\n<snippet>"`` for one Custom Search result item."""
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Simple moving average of ``close`` over *window* rows (NaN until the window fills)."""
        try:
            result = df['close'].rolling(window=window).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Relative Strength Index of ``close``.

        Uses simple rolling means of gains and losses (not Wilder's smoothing).
        """
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Exponential moving average of ``close`` with span *window* (``adjust=False``)."""
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        """Rolling mean of ``close`` with bands at +/- two rolling standard deviations."""
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> Optional[pd.DataFrame]:
        """MACD line (short EMA minus long EMA of ``close``) and its EMA signal line."""
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        """Rolling standard deviation of daily log returns, scaled by ``sqrt(window)``.

        NOTE(review): scaling by sqrt(window) yields volatility over a window-length
        horizon, not an annualized figure; confirm which one is intended.
        """
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        """Average True Range: rolling mean of the per-row true range."""
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """On-Balance Volume: cumulative volume signed by the direction of each close change."""
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Per calendar year: mean/max/min ``close`` and total ``volume``.

        Returns a frame indexed by ``year`` with columns ``close_mean``,
        ``close_max``, ``close_min`` and ``volume_sum``. *df* is not modified.
        """
        try:
            # Group by a derived Series instead of writing a 'year' column into the
            # caller's DataFrame (which leaked into every later use of stock_data).
            years = pd.to_datetime(df['date']).dt.year.rename('year')
            yearly_summary = df.groupby(years).agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Rows whose ``date`` falls in the previous calendar year (Jan 1 - Dec 31)."""
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        """Percent change from the first ``open`` of the current year to the latest ``close``.

        Returns ``None`` when there are no rows for the current year yet.
        """
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            if ytd_data.empty:
                # e.g. early January, before the year's first trading session.
                logger.error("Error calculating YTD performance: no data for the current year")
                return None
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        """Price-to-earnings ratio; ``None`` (with an error log) when *eps* is zero."""
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        """Scrape the first answer-box-like snippet from a Google results page.

        Returns the snippet text, ``"Snippet not found."`` when none of the known
        CSS classes match, or ``None`` on a request/parsing error. The class names
        are tied to Google's current markup and may stop matching at any time.
        """
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            # Pass the query via params so requests URL-encodes it; interpolating it
            # into the URL broke on characters such as '&', '#' and '+'.
            response = requests.get("https://www.google.com/search", params={"q": query},
                                    headers=headers, timeout=self.REQUEST_TIMEOUT)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None
def extract_ticker_from_response(response: str) -> Optional[str]:
    """Pull a ticker symbol out of an LLM reply.

    Prefers a bold span introduced by ``is`` (``"... is **AAPL**."``), then any
    bold span, and otherwise returns the whole reply stripped (for bare answers
    such as ``"AAPL"``). Returns ``None`` only if *response* is not a string.
    """
    import re

    try:
        # The old split on "**." only worked when the bold ticker was immediately
        # followed by a period; "is **MSFT** (NASDAQ)." returned the whole sentence.
        match = re.search(r"is \*\*(.+?)\*\*", response) or re.search(r"\*\*(.+?)\*\*", response)
        if match:
            return match.group(1).strip()
        return response.strip()
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Use the LLM to detect the query's language, translate it, and find a company and its ticker.

    Returns:
        ``(language, entity, translated_query, ticker)``. ``entity`` is ``None``
        when the LLM reports that no company is mentioned; ``ticker`` is ``None``
        when none could be extracted. All four are ``None`` if any step raises.
    """
    try:
        # Step 1: Detect language. Ask for a bare name; a full-sentence answer
        # ("The text is in English.") never equals "English" and forced a translation.
        prompt = (f"Detect the language of the following text. Reply with only the language name "
                  f"in English, for example 'English' or 'Arabic'.\n\nText: {query}")
        detected_language = invoke_llm(prompt).content.strip()

        # Step 2: Translate to English only when needed (tolerate case, quotes, trailing period).
        translated_query = query
        if detected_language.strip(" .'\"").lower() != "english":
            prompt = f"Translate the following text to English. Reply with only the translation.\n\nText: {query}"
            translated_query = invoke_llm(prompt).content.strip()

        # Step 3: Detect entity. Without an explicit sentinel the model replies in prose
        # even when no company is mentioned, so the "no entity" branch could never fire.
        prompt = (f"Detect the entity in the following text that is a company name. Reply with only "
                  f"the company name, or NONE if the text does not mention a company.\n\n"
                  f"Text: {translated_query}")
        detected_entity = invoke_llm(prompt).content.strip()
        if not detected_entity or detected_entity.strip(" .").upper() == "NONE":
            return detected_language, None, translated_query, None

        # Step 4: Get stock ticker.
        prompt = f"What is the stock ticker symbol for the company {detected_entity}? Reply with only the ticker symbol."
        stock_ticker = extract_ticker_from_response(invoke_llm(prompt).content.strip())
        if not stock_ticker:
            return detected_language, detected_entity, translated_query, None
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    """Download about three years (3 * 365 days) of daily price history for *symbol*.

    Returns a DataFrame with columns ``date`` (``datetime.date``), ``open``,
    ``high``, ``low``, ``close`` and ``volume``, or an empty DataFrame on failure.
    """
    column_map = {"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
    try:
        ticker = yf.Ticker(symbol)
        until = datetime.now()
        since = until - timedelta(days=3 * 365)
        frame = ticker.history(start=since, end=until).rename(columns=column_map).reset_index()
        # yfinance indexes by timestamp under "Date"; keep only the calendar date.
        frame['date'] = frame['Date'].dt.date
        return frame[['date', 'open', 'high', 'low', 'close', 'volume']]
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()
def fetch_current_stock_price(symbol: str) -> Optional[float]:
    """Return the latest price Yahoo Finance reports for *symbol*, or ``None``.

    ``currentPrice`` is preferred. ``regularMarketPrice`` is the fallback, since
    ``Ticker.info`` for some instruments (e.g. ETFs or indices) may omit
    ``currentPrice``.
    """
    try:
        info = yf.Ticker(symbol).info
        for key in ("currentPrice", "regularMarketPrice"):
            price = info.get(key)
            if price is not None:
                return price
        logger.error(f"Failed to fetch current stock price for {symbol}: no price field in ticker info")
        return None
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None
def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    """Render daily price history as a pipe-separated text table for the LLM prompt.

    Expects ``date``, ``open``, ``high``, ``low``, ``close`` and ``volume`` columns.
    Returns ``"No historical data available."`` for an empty frame and
    ``"Error formatting stock data."`` if formatting fails (e.g. a missing column).
    """
    try:
        if stock_data.empty:
            return "No historical data available."
        header = ("Historical stock data for the last three years:\n\n"
                  "Date | Open | High | Low | Close | Volume\n"
                  "------------------------------------------------------\n")
        columns = ['date', 'open', 'high', 'low', 'close', 'volume']
        # itertuples + join instead of iterrows + repeated string +=: iterrows builds a
        # Series per row, and the table typically has ~750 rows.
        rows = [
            f"{row.date} | {row.open:.2f} | {row.high:.2f} | {row.low:.2f} | {row.close:.2f} | {int(row.volume)}\n"
            for row in stock_data[columns].itertuples(index=False)
        ]
        return header + "".join(rows)
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."
def fetch_company_info_yahoo(symbol: str) -> Dict:
    """Fetch profile fields for *symbol* from Yahoo Finance's ``Ticker.info``.

    Text fields default to ``"N/A"`` when absent. ``trailingEps`` (default ``None``)
    is included because ``calculate_metrics`` reads it from this dict to decide
    whether to add a P/E ratio; without it the P/E branch could never run.
    On failure returns ``{"error": <message>}``.
    """
    try:
        stock = yf.Ticker(symbol)
        company_info = stock.info
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A"),
            "trailingEps": company_info.get("trailingEps"),
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}
def format_company_info_for_gemini(company_info: Dict) -> str:
    """Render a company-profile dict (as built by ``fetch_company_info_yahoo``) as prompt text.

    A dict carrying an ``"error"`` key is reported as a single error line. Any
    missing profile key yields ``"Error formatting company info."``.
    """
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"
        info = company_info
        lines = [
            "",
            "Company Information:",
            f"Name: {info['name']}",
            f"Sector: {info['sector']}",
            f"Industry: {info['industry']}",
            f"Market Cap: {info['marketCap']}",
            f"Summary: {info['summary']}",
            f"Website: {info['website']}",
            f"Address: {info['address']}, {info['city']}, {info['state']}, {info['country']}",
            f"Phone: {info['phone']}",
        ]
        # Leading "" gives the opening newline; the trailing "\n" ends the last line.
        return "\n".join(lines) + "\n"
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    """Return Yahoo Finance's news items for *symbol*; an empty list when there are none or on failure."""
    try:
        return yf.Ticker(symbol).news or []
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []
def format_company_news_for_gemini(news: List[Dict]) -> str:
    """Render Yahoo Finance news items as a text block for the LLM prompt.

    Two item layouts are accepted:
    - flat: ``title``/``publisher``/``link``/``providerPublishTime``;
    - nested under ``content``: ``title``/``provider.displayName``/``canonicalUrl.url``/``pubDate``.
      The nested layout is presumably what newer yfinance releases return; verify
      against the installed version.

    Missing fields render as ``N/A`` instead of aborting the whole block. Numeric
    (Unix epoch) publish times are rendered as UTC date-times.
    """
    from datetime import timezone

    def _first(*candidates):
        # First value that is not None; "N/A" if every candidate is missing.
        for value in candidates:
            if value is not None:
                return value
        return "N/A"

    def _published(value):
        # providerPublishTime is epoch seconds; a raw integer is meaningless to the LLM.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return datetime.fromtimestamp(value, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
        return value

    try:
        if not news:
            return "No news available."
        entries = []
        for article in news:
            item = article["content"] if isinstance(article.get("content"), dict) else article
            provider = item.get("provider") or {}
            canonical = item.get("canonicalUrl") or {}
            entries.append(
                f"Title: {_first(item.get('title'))}\n"
                f"Publisher: {_first(item.get('publisher'), provider.get('displayName'))}\n"
                f"Link: {_first(item.get('link'), canonical.get('url'))}\n"
                f"Published: {_published(_first(item.get('providerPublishTime'), item.get('pubDate')))}\n\n"
            )
        return "Latest company news:\n\n" + "".join(entries)
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."
def send_to_gemini_for_summarization(content: str) -> str:
    """Ask the LLM to summarize the main points of *content*.

    *content* is normally a single string. An iterable of strings is also
    accepted and joined with single spaces. Returns the stripped summary, or
    ``"Error summarizing content."`` on failure.
    """
    try:
        # " ".join() over a plain str iterates its characters ("abc" -> "a b c"),
        # so only join when handed a collection of pieces.
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."
def answer_question_with_data(question: str, data: Dict) -> str:
    """Have the LLM answer *question* in a financial-advisor voice, grounded in *data*.

    Each ``key: value`` pair of *data* is inlined into the prompt as a labelled
    section. Returns the stripped answer, or ``"Error answering question."``
    if building the prompt or the LLM call fails.
    """
    try:
        sections = [f"{key}:\n{value}\n\n" for key, value in data.items()]
        data_str = "".join(sections)
        prompt = (f"You are a financial advisor. Begin your answer and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        return invoke_llm(prompt).content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    """Compute technical indicators for *stock_data* and render each one as prompt text.

    Each ``summarizer.calculate_*`` method logs and returns ``None`` on failure.
    Such a metric is reported as ``"N/A"`` rather than discarding every metric
    (previously, e.g., a ``None`` YTD value in early January raised on ``:.2f``).
    ``"P/E Ratio"`` is added last, and only when ``company_info`` has a truthy
    ``trailingEps``.

    Returns:
        Metric name -> text, or ``{"Error": "Error calculating metrics"}`` if an
        unexpected exception escapes.
    """
    def _as_table(value) -> str:
        # pandas Series/DataFrame -> full text table; None (failed calculation) -> "N/A".
        return "N/A" if value is None else value.to_string()

    def _as_number(value, suffix: str = "") -> str:
        return "N/A" if value is None else f"{value:.2f}{suffix}"

    try:
        formatted_metrics = {
            "Moving Average": _as_table(summarizer.calculate_moving_average(stock_data)),
            "RSI": _as_table(summarizer.calculate_rsi(stock_data)),
            "EMA": _as_table(summarizer.calculate_ema(stock_data)),
            "Bollinger Bands": _as_table(summarizer.calculate_bollinger_bands(stock_data)),
            "MACD": _as_table(summarizer.calculate_macd(stock_data)),
            "Volatility": _as_table(summarizer.calculate_volatility(stock_data)),
            "ATR": _as_table(summarizer.calculate_atr(stock_data)),
            "OBV": _as_table(summarizer.calculate_obv(stock_data)),
            "Yearly Summary": _as_table(summarizer.calculate_yearly_summary(stock_data)),
            "YTD Performance": _as_number(summarizer.calculate_ytd_performance(stock_data), "%"),
        }
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            formatted_metrics["P/E Ratio"] = _as_number(summarizer.calculate_pe_ratio(current_price, eps))
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> Dict[str, str]:
    """Collect every context string for the answering prompt into one flat dict.

    Each entry of *formatted_metrics* becomes its own top-level key, added after
    the fixed sources. The metrics are intentionally not stored a second time
    under a nested "Calculations" key: that duplicated every (large) indicator
    table in the prompt, rendered as a dict repr.
    """
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
    }
    collected_data.update(formatted_metrics)
    return collected_data
def _run_in_parallel(tasks: Dict[str, Tuple]) -> Dict[str, object]:
    """Run ``{name: (func, *args)}`` tasks on a thread pool and return ``{name: result}``.

    A task that raises is logged and mapped to ``None``, so one failing data
    source does not abort the whole request.
    """
    results = {}
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {executor.submit(func, *args): name for name, (func, *args) in tasks.items()}
        for future in as_completed(futures):
            name = futures[future]
            try:
                results[name] = future.result()
            except Exception as e:
                logger.error(f"Task '{name}' failed: {e}")
                results[name] = None
    return results


def _result_or_default(results: Dict[str, object], key: str, default):
    """Return ``results[key]``, or *default* when the key is absent or its value is ``None``."""
    value = results.get(key)
    return default if value is None else value


# Without this decorator the view was never registered, so POST /ask returned 404.
@app.route('/ask', methods=['POST'])
def ask():
    """Answer a question posted as JSON ``{"question": "..."}``.

    When the LLM identifies a company and ticker, Yahoo Finance data (history,
    profile, news, current price, derived metrics) is gathered alongside the RAG
    answer and the Google search results. Otherwise only the latter are used.

    Returns:
        ``{"answer": ...}`` with 200; ``{"error": ...}`` with 400 for a missing or
        empty question, or with 500 on an unexpected failure.
    """
    payload = request.get_json(silent=True)
    user_input = payload.get('question') if isinstance(payload, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"error": "Request body must be JSON with a non-empty 'question' string."}), 400
    try:
        summarizer = DataSummarizer()
        _, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)
        has_company = bool(entity and stock_ticker)

        tasks = {
            "rag_response": (get_answer, user_input),
            "google_results": (summarizer.google_search, user_input),
            "google_snippet": (summarizer.fetch_google_snippet, user_input),
        }
        if has_company:
            tasks.update({
                "stock_data": (fetch_stock_data_yahoo, stock_ticker),
                "company_info": (fetch_company_info_yahoo, stock_ticker),
                "company_news": (fetch_company_news_yahoo, stock_ticker),
                "current_stock_price": (fetch_current_stock_price, stock_ticker),
            })
        results = _run_in_parallel(tasks)

        # Several sources return None on failure; fall back to the placeholder text.
        google_results = _result_or_default(results, "google_results", "No additional news found through Google Search.")
        google_snippet = _result_or_default(results, "google_snippet", "Snippet not found.")
        rag_response = _result_or_default(results, "rag_response", "No response from RAG.")

        if not has_company:
            collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)
            answer = answer_question_with_data(user_input, collected_data)
            return jsonify({"answer": answer})

        stock_data = _result_or_default(results, "stock_data", pd.DataFrame())
        formatted_stock_data = format_stock_data_for_gemini(stock_data) if not stock_data.empty else "No historical data available."
        company_info = _result_or_default(results, "company_info", {})
        formatted_company_info = format_company_info_for_gemini(company_info) if company_info else "No company info available."
        company_news = _result_or_default(results, "company_news", [])
        formatted_company_news = format_company_news_for_gemini(company_news) if company_news else "No news available."
        current_stock_price = results.get("current_stock_price")
        formatted_metrics = calculate_metrics(stock_data, summarizer, company_info) if not stock_data.empty else {"Error": "No stock data for metrics"}

        collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                      google_results, formatted_metrics, google_snippet, rag_response)
        collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"
        answer = answer_question_with_data(f"{translation}", collected_data)
        return jsonify({"answer": answer})
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return jsonify({"error": "An error occurred while processing your request. Please try again later."}), 500
# Streamlit App
def send_question_to_api(question, url='http://localhost:7860/ask', timeout=120):
    """POST *question* to the Flask ``/ask`` endpoint and return the answer text.

    Args:
        question: The user's question.
        url: Endpoint URL. The default targets port 7860, the port the
            ``__main__`` block starts the Flask server on (the old hard-coded
            port 5000 had nothing listening).
        timeout: Seconds to wait for the API. This is generous because answering
            involves several LLM and web calls.

    Returns:
        The ``answer`` field on HTTP 200; otherwise a string starting with ``"Error:"``.
    """
    try:
        response = requests.post(url, json={'question': question}, timeout=timeout)
    except requests.RequestException as e:
        # Show connection problems in the UI instead of crashing the Streamlit script.
        return f"Error: could not reach the API at {url} - {e}"
    if response.status_code == 200:
        return response.json().get('answer')
    return f"Error: {response.status_code} - {response.text}"
def run_streamlit():
    """Render the Streamlit tester page: question box, submit button and Q&A history.

    The history list lives in ``st.session_state`` so it accumulates across
    Streamlit reruns within one browser session.
    """
    st.title("Financial Data Chatbot Tester")
    st.write("Enter your question below and get a response from the chatbot.")
    if 'history' not in st.session_state:
        st.session_state.history = []

    question = st.text_input("Your question:", "")
    if st.button("Submit"):
        if not question:
            st.warning("Please enter a question before submitting.")
        else:
            with st.spinner('Getting the answer...'):
                reply = send_question_to_api(question)
                st.session_state.history.append((question, reply))
                st.success(reply)

    past = st.session_state.history
    if past:
        st.write("### History")
        for number, (asked, replied) in enumerate(past, start=1):
            st.write(f"**Q{number}:** {asked}")
            st.write(f"**A{number}:** {replied}")
            st.write("---")
def _ensure_api_server_running(thread_name: str = "flask-api-server") -> None:
    """Start the Flask API on 0.0.0.0:7860 in a background thread, at most once per process.

    Streamlit re-executes this script on every interaction. Starting a thread
    unconditionally would therefore try to bind port 7860 again on each rerun
    and fail with "address already in use". The check uses the named thread via
    ``threading.enumerate()``, which, unlike module globals, survives reruns in
    the same process.
    """
    if any(t.name == thread_name and t.is_alive() for t in threading.enumerate()):
        return
    threading.Thread(target=lambda: app.run(host='0.0.0.0', port=7860), name=thread_name).start()


if __name__ == '__main__':
    _ensure_api_server_running()
    run_streamlit()