temp / DV-AGENT /tools.py
NEXAS's picture
Upload 22 files
182219d verified
raw
history blame contribute delete
No virus
3.8 kB
from langchain_experimental.agents.agent_toolkits import create_csv_agent
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool
from langchain.chains import LLMMathChain
import streamlit as st
import pandas as pd
import plotly.express as px
import os
import streamlit as st
#from langchain_community.llms import HuggingFaceHub
from typing import List
from langchain_groq import ChatGroq
from dotenv import load_dotenv
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
llm1 = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")
def csv_agnet(string):
agent = create_csv_agent(
llm1,
"dataset.csv",
verbose=True,
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
ans = agent.invoke(string)
return ans
#def csv_agnet(string):
# agent = create_csv_agent(
# ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613"),
# "dataset.csv",
# verbose=True,
# agent_type=AgentType.OPENAI_FUNCTIONS,
# )
# ans = agent.run(string)
# return ans
def math_tool(string):
#llm = OpenAI(temperature=0)
llm = llm1
llm_math_chain = LLMMathChain(llm=llm, verbose=True)
res = llm_math_chain.run(string)
return res
def load_data():
df = pd.read_csv("dataset.csv", encoding="utf-8")
return df
def plot_visualization(selected_option, x_column, y_column):
df = load_data()
if df.empty:
return st.warning("The data is empty.")
if x_column not in df.columns or y_column not in df.columns:
return st.warning("Invalid columns selected.")
if selected_option == "bar":
fig = px.bar(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
elif selected_option == "scatter":
fig = px.scatter(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
elif selected_option == "line":
fig = px.line(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
elif selected_option == "scatter_matrix":
fig = px.scatter_matrix(df, dimensions=[x_column, y_column], title=f"Scatter Matrix: {x_column} vs {y_column}")
elif selected_option == "box":
fig = px.box(df, x=x_column, y=y_column, title=f"Box Plot: {x_column} vs {y_column}")
elif selected_option == "heatmap":
fig = px.imshow(df.pivot_table(index=x_column, columns=y_column, aggfunc='size').fillna(0),
labels=dict(x=x_column, y=y_column),
title=f"Heatmap: {x_column} vs {y_column}")
else:
return st.warning("Please select a valid plot type.")
return st.plotly_chart(fig)
def parsing_input(string):
selected_option, x_column, y_column = string.split(",")
return plot_visualization(selected_option, x_column, y_column)
zeroshot_tools = [
Tool(
name="answer_qa",
func=csv_agnet,
description="Use this tool to query the dataset. input to this tool should be a standalone question. Include the correct row titles that are needed. Example Input format: How many rows are there in the dataset, which name has the highest calories",
#return_direct=True,
),
Tool(
name="create_simple_plot",
func=parsing_input,
description="""Use this tool if the user asks to create x vs y plots. input must be a comma seperated list of: selected_option, x_column, y_column
Example Inputs:
bar,calories,name
Allowed options are: bar, line, scatter_matrix, box, heatmap
you can decide plot type, x colllumn and y collumn based on the user input.
""",
#return_direct=True,
),
Tool(
name="Calculator",
func=math_tool,
description="useful when you need to do calculations. Example input: 21^0.43"
),
]