"""Streamlit "Trading Helper Agent".

Wires a HuggingFace text-generation pipeline into a LangChain single-action
agent and gives it tools for stock prices (yfinance), news (NewsAPI), news
sentiment, moving averages, ARIMA forecasts and two-stock comparison.
"""

import os
import re
from typing import List, Union

import requests
import streamlit as st
import torch
import yfinance as yf
from dotenv import load_dotenv
from langchain import LLMChain
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import BaseChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, HumanMessage
from langchain.tools import Tool
from langchain_huggingface import HuggingFacePipeline
from statsmodels.tsa.arima.model import ARIMA
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables from .env
load_dotenv()
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
access_token = os.getenv("API_KEY")

# Check if the access token and API key are present
if not NEWSAPI_KEY or not access_token:
    raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")

MODEL_ID = "google/gemma-2b-it"
NEWSAPI_URL = "https://newsapi.org/v2/everything"
REQUEST_TIMEOUT_SECONDS = 10


@st.cache_resource
def _load_generation_pipeline(model_id, token):
    """Load tokenizer, model and text-generation pipeline once per process.

    Streamlit re-executes this script on every interaction; caching the
    loaded objects avoids reloading the model weights on each rerun.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    generation_pipeline = pipeline(
        "text-generation",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        max_new_tokens=512,
    )
    return loaded_tokenizer, loaded_model, generation_pipeline


@st.cache_resource
def _load_sentiment_pipeline():
    """Build the default sentiment-analysis pipeline once and reuse it."""
    return pipeline("sentiment-analysis")


# Initialize the model and tokenizer for the HuggingFace pipeline
tokenizer, model, pipe = _load_generation_pipeline(MODEL_ID, access_token)


# Define functions for fetching stock data, news, and moving averages
def validate_ticker(ticker):
    """Return *ticker* with surrounding whitespace removed and upper-cased."""
    return ticker.strip().upper()


def fetch_stock_data(ticker):
    """Return the last five rows of one month of price history for *ticker*.

    Returns:
        ``DataFrame.to_dict()`` of the last five rows (column name mapped to
        a dict of index -> value), or ``{"error": message}`` when no data is
        found or any exception is raised while fetching.
    """
    try:
        ticker = validate_ticker(ticker)
        hist = yf.Ticker(ticker).history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {ticker}"}
        return hist.tail(5).to_dict()
    except Exception as e:  # tool boundary: report the failure to the agent
        return {"error": str(e)}


def fetch_stock_news(ticker, NEWSAPI_KEY):
    """Fetch up to five news articles matching *ticker* from NewsAPI.

    Args:
        ticker: Search term sent as the ``q`` query parameter.
        NEWSAPI_KEY: API key sent as the ``apiKey`` query parameter.

    Returns:
        A list of ``{"title": ..., "description": ...}`` dicts, or
        ``[{"error": "Unable to fetch news."}]`` when the request fails,
        returns a non-200 status, or the body is not valid JSON.
    """
    failure = [{"error": "Unable to fetch news."}]
    # Passing params lets requests URL-encode the query instead of
    # interpolating raw text into the URL.
    query_params = {"q": ticker.strip(), "apiKey": NEWSAPI_KEY}
    try:
        response = requests.get(
            NEWSAPI_URL,
            params=query_params,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException:
        return failure
    if response.status_code != 200:
        return failure
    try:
        articles = response.json().get("articles") or []
    except ValueError:
        return failure
    return [
        {"title": article.get("title"), "description": article.get("description")}
        for article in articles[:5]
    ]


def calculate_moving_average(ticker, window=5):
    """Compute a rolling mean of the closing price over one month of history.

    Args:
        ticker: Stock ticker symbol.
        window: Number of rows in the rolling window.

    Returns:
        The last five rows of a DataFrame with ``Close`` and
        ``"{window}-day MA"`` columns, or ``{"error": message}`` when no
        history is returned for the ticker.
    """
    ticker = validate_ticker(ticker)
    hist = yf.Ticker(ticker).history(period="1mo")
    if hist.empty:
        return {"error": f"No data found for ticker {ticker}"}
    ma_column = f"{window}-day MA"
    hist[ma_column] = hist["Close"].rolling(window=window).mean()
    return hist[["Close", ma_column]].tail(5)


def analyze_sentiment(news_articles):
    """Run sentiment analysis over the articles from ``fetch_stock_news``.

    The description is classified when present, otherwise the title.
    Error entries (dicts containing an ``"error"`` key) are passed through
    unchanged, and articles with neither description nor title are skipped.

    Returns:
        A list of ``{"title": ..., "sentiment": ...}`` dicts, where
        ``sentiment`` is the first result returned by the pipeline.
    """
    sentiment_pipeline = _load_sentiment_pipeline()
    results = []
    for article in news_articles:
        if "error" in article:
            results.append(article)
            continue
        text = article.get("description") or article.get("title")
        if not text:
            continue
        results.append(
            {"title": article.get("title"), "sentiment": sentiment_pipeline(text)[0]}
        )
    return results


def predict_stock_price(ticker, days=5):
    """Forecast closing prices with an ARIMA(5, 1, 0) model.

    The model is fitted on six months of closing prices.

    Args:
        ticker: Stock ticker symbol.
        days: Number of forecast steps.

    Returns:
        A list of forecast values, or ``{"error": message}`` when no history
        is returned for the ticker.
    """
    ticker = validate_ticker(ticker)
    hist = yf.Ticker(ticker).history(period="6mo")
    if hist.empty:
        return {"error": f"No data found for ticker {ticker}"}
    arima_model = ARIMA(hist["Close"], order=(5, 1, 0))
    fitted_model = arima_model.fit()
    forecast = fitted_model.forecast(steps=days)
    return forecast.tolist()


def _latest_close(stock_data):
    """Return the last closing price from a ``fetch_stock_data`` result."""
    # "Close" maps index labels to prices in row order, so the last value is
    # the most recent row; indexing with [-1] would raise KeyError.
    return list(stock_data["Close"].values())[-1]


def compare_stocks(ticker1, ticker2):
    """Compare the most recent closing price of two tickers.

    Returns:
        ``{ticker: {"recent_close": price}}`` for both normalised tickers, or
        ``{"error": message}`` when data for either one cannot be fetched.
    """
    ticker1 = validate_ticker(ticker1)
    ticker2 = validate_ticker(ticker2)
    data1 = fetch_stock_data(ticker1)
    data2 = fetch_stock_data(ticker2)
    if "error" in data1 or "error" in data2:
        return {"error": "Could not fetch stock data for comparison."}
    return {
        ticker1: {"recent_close": _latest_close(data1)},
        ticker2: {"recent_close": _latest_close(data2)},
    }


def _compare_ticker_pair(tickers):
    """Parse ``'AAPL,MSFT'``-style input and delegate to ``compare_stocks``."""
    parts = [part.strip() for part in tickers.split(",") if part.strip()]
    if len(parts) != 2:
        return {
            "error": "Provide exactly two comma-separated tickers, e.g. 'AAPL,MSFT'."
        }
    return compare_stocks(parts[0], parts[1])


# Define LangChain tools
stock_data_tool = Tool(
    name="Stock Data Fetcher",
    func=fetch_stock_data,
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).",
)

stock_news_tool = Tool(
    name="Stock News Fetcher",
    func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
    description="Fetch recent news articles about a stock ticker.",
)

moving_average_tool = Tool(
    name="Moving Average Calculator",
    func=calculate_moving_average,
    description="Calculate the moving average of a stock over a 5-day window.",
)

sentiment_tool = Tool(
    name="News Sentiment Analyzer",
    func=lambda ticker: analyze_sentiment(fetch_stock_news(ticker, NEWSAPI_KEY)),
    description="Analyze the sentiment of recent news articles about a stock ticker.",
)

stock_prediction_tool = Tool(
    name="Stock Price Predictor",
    func=predict_stock_price,
    description="Predict future stock prices for a given ticker based on historical data.",
)

stock_comparator_tool = Tool(
    name="Stock Comparator",
    func=_compare_ticker_pair,
    description="Compare the recent performance of two stocks given their tickers, e.g., 'AAPL,MSFT'.",
)

tools = [
    stock_data_tool,
    stock_news_tool,
    moving_average_tool,
    sentiment_tool,
    stock_prediction_tool,
    stock_comparator_tool,
]

# Set up a prompt template with history
# NOTE(review): the template itself contains the text "Final Answer:". If the
# LLM wrapper echoes the prompt in its output, CustomOutputParser would treat
# every response as finished -- confirm only generated text is returned.
template_with_history = """You are SearchGPT, a professional search engine who provides informative answers to users. Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! Remember to give detailed, informative answers

Previous conversation history:
{history}

New question: {input}
{agent_scratchpad}"""


# Set up the prompt template
class CustomPromptTemplate(BaseChatPromptTemplate):
    """Chat prompt that renders the tool list and the agent scratchpad."""

    # Template text with {tools}, {tool_names}, {history}, {input} and
    # {agent_scratchpad} placeholders.
    template: str
    # Tools whose names and descriptions are rendered into the prompt.
    tools: List[Tool]

    def format_messages(self, **kwargs) -> List[HumanMessage]:
        """Render the template into a single ``HumanMessage``.

        ``intermediate_steps`` (pairs of action and observation) is removed
        from *kwargs* and turned into the ``agent_scratchpad`` text.
        """
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        kwargs["agent_scratchpad"] = thoughts
        kwargs["tools"] = "\n".join(
            [f"{tool.name}: {tool.description}" for tool in self.tools]
        )
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        formatted = self.template.format(**kwargs)
        return [HumanMessage(content=formatted)]


prompt_with_history = CustomPromptTemplate(
    template=template_with_history,
    tools=tools,
    input_variables=["input", "intermediate_steps", "history"],
)


# Custom output parser
class CustomOutputParser(AgentOutputParser):
    """Parse LLM text into either a tool call or the final answer."""

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """Return ``AgentFinish`` or ``AgentAction`` for *llm_output*.

        Raises:
            ValueError: If the text contains neither ``Final Answer:`` nor an
                ``Action:`` / ``Action Input:`` pair.
        """
        if "Final Answer:" in llm_output:
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        return AgentAction(
            tool=action,
            tool_input=action_input.strip(" ").strip('"'),
            log=llm_output,
        )


output_parser = CustomOutputParser()

# Initialize HuggingFace pipeline
llm = HuggingFacePipeline(pipeline=pipe)

# LLM chain
llm_chain = LLMChain(llm=llm, prompt=prompt_with_history)
tool_names = [tool.name for tool in tools]

agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["\nObservation:"],
    allowed_tools=tool_names,
)

# Keep the memory in session state so the conversation history survives
# Streamlit reruns instead of being recreated empty on every interaction.
if "memory" not in st.session_state:
    st.session_state["memory"] = ConversationBufferWindowMemory(k=2)
memory = st.session_state["memory"]

agent_executor = AgentExecutor.from_agent_and_tools(
    agent=agent,
    tools=tools,
    verbose=True,
    memory=memory,
)

# Streamlit app
st.title("Trading Helper Agent")

query = st.text_input("Enter your query:")

if st.button("Submit"):
    if query:
        st.write("Debug: User Query ->", query)
        with st.spinner("Processing..."):
            try:
                # Run the agent and get the response
                response = agent_executor.run(query)
                st.success("Response:")
                st.write(response)
            except Exception as e:  # top-level UI boundary: show the failure
                st.error(f"An error occurred: {e}")
                # Log the full LLM output for debugging
                if hasattr(e, "output"):
                    st.write("Raw Output:", e.output)
    else:
        st.warning("Please enter a query.")