# Hugging Face Spaces page header (status: Sleeping), captured along with the source.
import html
import os
import tempfile
from typing import Optional, List

import g4f
import pandas as pd
import streamlit as st
import yfinance as yf
from langchain.agents import create_csv_agent, AgentType
from langchain.chat_models import ChatOpenAI
from langchain.llms.base import LLM

from htmlTemplates import css, user_template, bot_template
# The OpenAI key is expected in the OPENAI_API_KEY environment variable.
# Only report whether it is present: printing the key itself would leak the
# secret into the server logs. The variable is not copied back into
# os.environ, because assigning None there raises TypeError when it is unset.
if not os.getenv('OPENAI_API_KEY'):
    print("Warning: OPENAI_API_KEY is not set.")

try:
    llm = ChatOpenAI(
        model='gpt-3.5-turbo',
        max_tokens=500,
        temperature=0.7,
    )
except Exception as err:  # startup boundary: degrade instead of refusing to start
    # NOTE(review): depending on the langchain/openai versions, a missing or
    # invalid key is presumably reported here as a ValueError or an openai
    # error -- confirm. With llm = None the charts still work and chat requests
    # fail with an error message instead of the whole app failing to start.
    print(f"Warning: could not initialise ChatOpenAI: {err}")
    llm = None
def init_ses_states():
    """Make sure the Streamlit session has a ``chat_history`` list.

    On the first run of a session an empty list is stored under
    ``'chat_history'``; on later reruns the existing list is left untouched.
    """
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []
def relative_returns(df):
    """Convert a price Series/DataFrame into cumulative returns.

    Each value is the compounded growth since the start of the data minus
    one, so ``0.21`` means +21%. Positions that come out missing (including
    the first row, which has no previous price) are reported as ``0``.

    Args:
        df: pandas Series or DataFrame of prices (one column per asset).

    Returns:
        An object of the same kind and index as ``df`` holding cumulative returns.
    """
    period_changes = df.pct_change()
    growth = (period_changes + 1).cumprod()
    return growth.sub(1).fillna(0)
def display_convo():
    """Render the chat history, newest message first.

    ``main`` stores each exchange in ``st.session_state.chat_history`` as a
    ``"USER: ..."`` entry followed by an ``"AI: ..."`` entry. The prefix picks
    the template, so the rendering stays correct even if the list ever holds
    an unpaired entry (the previous index-parity check would then swap every
    speaker).

    The message text is HTML-escaped before being substituted into the
    template. The template is rendered with ``unsafe_allow_html=True``, and
    both the user query and the model's answer are untrusted text.
    """
    with st.container():
        for message in reversed(st.session_state.chat_history):
            template = user_template if message.startswith("USER:") else bot_template
            st.markdown(
                template.replace("{{MSG}}", html.escape(message)),
                unsafe_allow_html=True,
            )
def _load_prices(tickers, start, end, metric):
    """Download adjusted close prices for *tickers* over [*start*, *end*).

    When *metric* is ``'Relative Returns'`` the prices are converted with
    ``relative_returns``; otherwise they are returned as downloaded.
    ``auto_adjust=False`` is requested explicitly so that the ``'Adj Close'``
    column is present: newer yfinance releases default to
    ``auto_adjust=True``, which omits that column.
    """
    prices = yf.download(tickers, start, end, auto_adjust=False)['Adj Close']
    if metric == 'Relative Returns':
        prices = relative_returns(prices)
    return prices


def _answer_query(data, prompt):
    """Run a LangChain CSV agent over *data* and return its answer to *prompt*.

    The data is written to a per-call temporary directory instead of a shared
    ``data.csv`` in the working directory, so concurrent sessions cannot
    overwrite each other's file. The directory is removed when the call ends.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, 'data.csv')
        pd.DataFrame(data).to_csv(csv_path)
        agent = create_csv_agent(
            llm,
            csv_path,
            verbose=True,
            agent_type=AgentType.OPENAI_FUNCTIONS,
        )
        return agent.run(prompt)


def main():
    """Build the Streamlit page: asset picker, price charts and a data chatbot.

    1. Sidebar widgets choose tickers, the metric, chart types and a date range.
    2. Prices are downloaded (see ``_load_prices``) and charted when requested.
    3. On "Execute", a CSV agent answers the query over the loaded data, and
       the exchange is appended to ``st.session_state.chat_history``.
    4. The chat history is rendered on every rerun.
    """
    st.set_page_config(page_title="Stock Price AI Bot", page_icon=":chart:")
    st.write(css, unsafe_allow_html=True)
    init_ses_states()
    st.title("Stock Price AI Bot")
    st.caption("Visualizations and OpenAI Chatbot for Multiple Stocks Over A Specified Period")

    with st.sidebar:
        # Tesla and Lyft trade as 'TSLA' and 'LYFT' (previously misspelled 'TSL' / 'LYFY').
        asset_tickers = sorted([
            'DOW', 'NVDA', 'TSLA', 'GOOGL', 'AMZN', 'AI', 'NIO', 'LCID', 'F',
            'LYFT', 'AAPL', 'MSFT', 'BTC-USD', 'ETH-USD',
        ])
        asset_dropdown = st.multiselect('Pick Assets:', asset_tickers)
        metric_tickers = ['Adj. Close', 'Relative Returns']
        metric_dropdown = st.selectbox("Metric", metric_tickers)
        viz_tickers = ['Line Chart', 'Area Chart']
        viz_dropdown = st.multiselect("Pick Charts:", viz_tickers)
        start = st.date_input('Start', value=pd.to_datetime('2023-01-01'))
        end = st.date_input('End', value=pd.to_datetime('today'))

    # df stays None until usable data is loaded. The chat section checks this
    # instead of raising NameError when no asset has been picked.
    df = None
    if start >= end:
        st.error("The start date must be before the end date.")
    elif asset_dropdown:
        df = _load_prices(asset_dropdown, start, end, metric_dropdown)
        if df.empty:
            st.warning("No price data was returned for the selected assets and dates.")
            df = None
        elif viz_dropdown:
            with st.expander("Data Visualizations", expanded=True):
                if "Line Chart" in viz_dropdown:
                    st.line_chart(df)
                if "Area Chart" in viz_dropdown:
                    st.area_chart(df)

    st.header("Chat with your Data")
    query = st.text_input("Enter a query:")
    chat_prompt = f'''
You are an AI ChatBot intended to help with user stock data.
\nDATA MODE: {metric_dropdown}
\nSTOCKS: {asset_dropdown}
\nTIME PERIOD: {start} to {end}
\nCHAT HISTORY: {st.session_state.chat_history}
\nUSER MESSAGE: {query}
\nAI RESPONSE HERE:
'''
    if st.button("Execute") and query:
        if df is None:
            st.warning("Pick at least one asset with available price data before asking a question.")
        else:
            with st.spinner('Generating response...'):
                try:
                    answer = _answer_query(df, chat_prompt)
                except Exception as e:
                    # UI boundary: show any agent/LLM failure instead of crashing the page.
                    st.error(f"An error occurred: {str(e)}")
                else:
                    st.session_state.chat_history.append(f"USER: {query}\n")
                    st.session_state.chat_history.append(f"AI: {answer}\n")

    # Render on every rerun, not only right after a new answer, so the
    # conversation stays visible when other widgets change.
    display_convo()
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()