File size: 3,801 Bytes
182219d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from langchain_experimental.agents.agent_toolkits import create_csv_agent
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool
from langchain.chains import LLMMathChain
import streamlit as st
import pandas as pd
import plotly.express as px
import os
import streamlit as st
#from langchain_community.llms import HuggingFaceHub
from typing import List
from langchain_groq import ChatGroq
from dotenv import load_dotenv
# Load environment configuration (python-dotenv) before anything below reads it.
load_dotenv()

# Groq API key as read from the environment; None when GROQ_API_KEY is unset.
# NOTE(review): this variable is not passed to ChatGroq below -- presumably the
# client picks GROQ_API_KEY up from the environment on its own; confirm.
groq_api_key = os.getenv("GROQ_API_KEY")

# Shared chat model instance used by csv_agnet and math_tool.
llm1 = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")

def csv_agnet(string):
    """Run a question against ``dataset.csv`` through a CSV agent.

    A zero-shot ReAct CSV agent is created on the shared ``llm1`` model for
    every call and invoked with ``string``.

    Args:
        string: The question handed to the agent.

    Returns:
        The value returned by the agent's ``invoke`` call.
    """
    agent_options = {
        "verbose": True,
        "agent_type": AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    }
    csv_executor = create_csv_agent(llm1, "dataset.csv", **agent_options)
    return csv_executor.invoke(string)

#def csv_agnet(string):
#    agent = create_csv_agent(
#        ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613"),
#        "dataset.csv",
#        verbose=True,
#        agent_type=AgentType.OPENAI_FUNCTIONS,
#    )

#    ans = agent.run(string)
#    return ans

def math_tool(string):
    """Evaluate a calculation with an LLM-backed math chain.

    Args:
        string: The expression or calculation request, e.g. ``"21^0.43"``.

    Returns:
        The result produced by the chain's ``run`` call.
    """
    # Build the chain through from_llm: passing ``llm=`` straight to the
    # LLMMathChain constructor is the deprecated construction path.
    llm_math_chain = LLMMathChain.from_llm(llm1, verbose=True)
    return llm_math_chain.run(string)

def load_data():
    """Read ``dataset.csv`` (UTF-8) from the working directory into a DataFrame."""
    return pd.read_csv("dataset.csv", encoding="utf-8")

def plot_visualization(selected_option, x_column, y_column):
    """Render a Plotly chart of two dataset columns in the Streamlit app.

    The dataset is reloaded through ``load_data`` on every call. A Streamlit
    warning is shown instead of a chart when the data is empty, when either
    column is missing, or when the plot type is not recognised.

    Args:
        selected_option: Plot type; one of ``"bar"``, ``"scatter"``,
            ``"line"``, ``"scatter_matrix"``, ``"box"`` or ``"heatmap"``.
        x_column: Name of the column shown on the x axis.
        y_column: Name of the column shown on the y axis.

    Returns:
        The result of the ``st.warning`` or ``st.plotly_chart`` call.
    """
    df = load_data()

    if df.empty:
        return st.warning("The data is empty.")

    if x_column not in df.columns or y_column not in df.columns:
        return st.warning("Invalid columns selected.")

    if selected_option == "bar":
        fig = px.bar(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
    elif selected_option == "scatter":
        fig = px.scatter(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
    elif selected_option == "line":
        fig = px.line(df, x=x_column, y=y_column, title=f"{x_column} vs {y_column}")
    elif selected_option == "scatter_matrix":
        fig = px.scatter_matrix(df, dimensions=[x_column, y_column], title=f"Scatter Matrix: {x_column} vs {y_column}")
    elif selected_option == "box":
        fig = px.box(df, x=x_column, y=y_column, title=f"Box Plot: {x_column} vs {y_column}")
    elif selected_option == "heatmap":
        # px.imshow draws the frame's columns along x and its index along y,
        # so y_column must be the pivot index for the axis labels to be true.
        counts = df.pivot_table(index=y_column, columns=x_column, aggfunc='size').fillna(0)
        fig = px.imshow(counts,
                        labels=dict(x=x_column, y=y_column),
                        title=f"Heatmap: {x_column} vs {y_column}")
    else:
        return st.warning("Please select a valid plot type.")

    return st.plotly_chart(fig)


def parsing_input(string):
    """Parse a ``"plot_type,x_column,y_column"`` request and draw the plot.

    Whitespace around each field is ignored, so ``"bar, calories, name"`` is
    treated the same as ``"bar,calories,name"``.

    Args:
        string: Comma-separated plot type, x column and y column.

    Returns:
        Whatever ``plot_visualization`` returns for the parsed fields.

    Raises:
        ValueError: If ``string`` does not contain exactly three
            comma-separated fields.
    """
    # Tool input often arrives with spaces after the commas; unstripped
    # fields would never match a plot type or a column name.
    parts = [part.strip() for part in string.split(",")]
    if len(parts) != 3:
        raise ValueError(
            "Expected 'selected_option,x_column,y_column' but got "
            f"{len(parts)} comma-separated value(s): {string!r}"
        )
    selected_option, x_column, y_column = parts
    return plot_visualization(selected_option, x_column, y_column)


# Tools exposed to the zero-shot agent: dataset Q&A, simple x-vs-y plotting,
# and a calculator. Each description is the text the agent reads when
# deciding which tool to call and how to format its input.
zeroshot_tools = [
    Tool(
        name="answer_qa",
        func=csv_agnet,
        description="Use this tool to query the dataset. input to this tool should be a standalone question. Include the correct row titles that are needed. Example Input format: How many rows are there in the dataset, which name has the highest calories",
        #return_direct=True,
    ),
    Tool(
        name="create_simple_plot",
        func=parsing_input,
        # The allowed-options list mirrors the plot types handled by
        # plot_visualization, including "scatter".
        description="""Use this tool if the user asks to create x vs y plots. input must be a comma seperated list of: selected_option, x_column, y_column
        Example Inputs: 
        bar,calories,name

        Allowed options are: bar, scatter, line, scatter_matrix, box, heatmap
        you can decide plot type, x colllumn and y collumn based on the user input.
        """,
        #return_direct=True,
    ),
    Tool(
        name="Calculator",
        func=math_tool,
        description="useful when you need to do calculations. Example input: 21^0.43"
    ),
]