import gradio as gr
import pandas as pd
import numpy as np
import os
import statsmodels.api as sm
from io import StringIO
# --- LangChain Imports ---
from langchain_groq import ChatGroq
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
# --- ASSUMPTION ---
# Assuming you have a file named 'sql_tools.py' in the same directory
# with your pre-built and decorated @tool functions.
try:
    from sql_tools import run_duckdb_query, get_table_schema
except ImportError:
    print("WARNING: Could not import from 'sql_tools.py'.")
    print("Using placeholder functions. Please create 'sql_tools.py'.")

    # Create placeholder tools if the file is missing, so the app can start
    @tool
    def run_duckdb_query(query: str) -> str:
        """
        [PLACEHOLDER] Runs a read-only SQL query.
        Please create sql_tools.py to implement this.
        """
        if "schema" in query.lower() or "describe" in query.lower():
            return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"
        return "Error: 'sql_tools.py' not found. This is a placeholder."

    @tool
    def get_table_schema(table_name: str = "positions") -> str:
        """
        [PLACEHOLDER] Returns the schema for the 'positions' table.
        Please create sql_tools.py to implement this.
        """
        return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"
# --- Agent Tools ---
# These tools perform analysis on data *after* it has been fetched.
@tool
def calculate_summary_statistics_from_data(data_string: str, column: str) -> str:
    """
    Calculates summary statistics (mean, median, std, min, max) for a specific
    'column' from a 'data_string'.
    'data_string' should be the string output from the `run_duckdb_query` tool.
    """
    try:
        # Convert the string data back into a DataFrame
        # (sep=r"\s+" replaces the deprecated delim_whitespace=True)
        data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0)
        # HACK: the string output may carry an extra index column; if the target
        # column is missing, re-read treating the first column as an unnamed index
        if column not in data_df.columns:
            data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0, index_col=0)
        if column not in data_df.columns:
            return f"Error: Column '{column}' not found in data."
        stats = {
            "column": column,
            "mean": data_df[column].mean(),
            "median": data_df[column].median(),
            "std_dev": data_df[column].std(),
            "min": data_df[column].min(),
            "max": data_df[column].max(),
            "count": data_df[column].count()
        }
        return str(stats)
    except Exception as e:
        return f"Error in calculate_summary_statistics: {e}. Data input was: '{data_string[:200]}...'"
@tool
def perform_arima_forecast_from_data(data_string: str, time_column: str, value_column: str, forecast_periods: int) -> str:
    """
    Performs an ARIMA(1,1,1) forecast on a 'data_string'.
    'data_string': The string output from `run_duckdb_query`.
    'time_column': The name of the date/time column in the data.
    'value_column': The name of the numerical column to forecast.
    'forecast_periods': The number of periods (e.g., days) to forecast.
    The data MUST be ordered by the time_column before being passed to this tool.
    """
    try:
        # Convert the string data back into a DataFrame
        data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0)
        # HACK: the string output may carry an extra index column (see above)
        if time_column not in data_df.columns:
            data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0, index_col=0)
        if time_column not in data_df.columns:
            return f"Error: Time column '{time_column}' not found in data."
        if value_column not in data_df.columns:
            return f"Error: Value column '{value_column}' not found in data."
        if data_df.empty:
            return "Error: Query returned no data."
        # Prepare data for statsmodels
        data_df[time_column] = pd.to_datetime(data_df[time_column])
        data_df = data_df.set_index(time_column)
        data_df = data_df.asfreq('D')  # enforce daily frequency; missing dates become NaN
        data_df[value_column] = data_df[value_column].ffill()  # forward-fill gaps (fillna(method=...) is deprecated)
        model = sm.tsa.ARIMA(data_df[value_column], order=(1, 1, 1))
        results = model.fit()
        forecast = results.forecast(steps=forecast_periods)
        forecast_df = pd.DataFrame({
            'date': forecast.index.strftime('%Y-%m-%d'),
            'forecasted_value': forecast.values
        })
        return f"Forecast successful. Last historical value was {data_df[value_column].iloc[-1]:.2f}.\nForecast:\n{forecast_df.to_string()}"
    except Exception as e:
        return f"Error in perform_arima_forecast: {e}. Data input was: '{data_string[:200]}...'"
# --- Main Agent and UI Setup ---
# Check for the GROQ_API_KEY in Hugging Face Space Secrets
if "GROQ_API_KEY" not in os.environ:
print("GROQ_API_KEY not found in secrets!")
def missing_key_error(message, history):
return "Error: `GROQ_API_KEY` is not set in this Space's Secrets. Please add it to use the app."
gr.ChatInterface(
missing_key_error,
title="Agentic Portfolio Analyst",
description="Error: GROQ_API_KEY secret is missing."
).launch()
else:
print("GROQ_API_KEY found. Initializing agent...")
llm = ChatGroq(model_name="llama-3.3-70b-versatile")
    # 2. Collect all our tools (imported and local)
    tools = [
        run_duckdb_query,
        get_table_schema,
        calculate_summary_statistics_from_data,
        perform_arima_forecast_from_data
    ]

    # 3. Create the Agent Prompt
    system_prompt = """
You are an expert portfolio analyst. You have access to SQL tools and analysis tools.

Your logic MUST follow these steps:
1. Use `get_table_schema` to understand the data.
2. Use `run_duckdb_query` to fetch the raw data you need.
3. If analysis (statistics or forecasting) is needed, take the string output
   from `run_duckdb_query` and pass it *directly* to either
   `calculate_summary_statistics_from_data` or `perform_arima_forecast_from_data`.

Example for forecasting:
1. Call `run_duckdb_query("SELECT report_date, SUM(market_value_usd) AS total_value FROM positions WHERE sector = 'Tech' GROUP BY report_date ORDER BY report_date")`.
2. Get the result string: " report_date total_value \n 2024-01-01 100000.0 \n 2024-01-02 100500.0 \n ..."
3. Call `perform_arima_forecast_from_data(data_string=" report_date total_value \n 2024-01-01 100000.0 \n ...", time_column="report_date", value_column="total_value", forecast_periods=30)`.

Answer the user's request based on the final tool output.
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )
    # 4. Create the Agent
    agent = create_tool_calling_agent(llm, tools, prompt)

    # 5. Create the Agent Executor
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True
    )
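    # verbose=True prints each tool call and observation to the Space's
    # container logs, which helps when debugging the agent's tool-use chain.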
    # 6. Define the chat function for Gradio
    def run_agent(message, history):
        # Gradio's ChatInterface (pair-style history) passes a list of
        # [user_message, assistant_message] pairs; convert them into the
        # (role, content) tuples that LangChain's chat_history placeholder expects.
        chat_history = []
        for human_msg, ai_msg in history:
            chat_history.append(("human", human_msg))
            chat_history.append(("ai", ai_msg))
        try:
            response = agent_executor.invoke({
                "input": message,
                "chat_history": chat_history
            })
            return response["output"]
        except Exception as e:
            return f"An error occurred: {e}"
    # 7. Launch the Gradio App
    gr.ChatInterface(
        run_agent,
        title="Agentic Portfolio Analyst",
        description="Ask me questions about your portfolio. (This app uses imported SQL tools.)",
        examples=[
            "What is the schema of the positions table?",
            "What's the total market value by sector on the last available date?",
            "Give me summary statistics for the 'Tech' sector's market value from portfolio P-123. Use the 'market_value_usd' column for stats.",
            "What is the 30-day forecast for the total market value of portfolio P-123? Use 'total_value' for the forecast value column."
        ]
    ).launch()
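# For reference, the imports above imply a requirements.txt along these lines
# (an assumption -- pin versions to match your environment):
#   gradio
#   pandas
#   numpy
#   statsmodels
#   langchain
#   langchain-groq
#   duckdb  # only if sql_tools.py uses DuckDB, as sketched above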