0xrsydn commited on
Commit
5b12cd5
1 Parent(s): c5791c8

wrappped into class

Browse files
Files changed (1) hide show
  1. ai_agent.py +31 -28
ai_agent.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pandas as pd
3
  import tabulate
@@ -12,33 +13,35 @@ from langchain_core.prompts import ChatPromptTemplate
12
  from langchain_experimental.tools import PythonAstREPLTool
13
  from langchain_experimental.agents import create_pandas_dataframe_agent
14
 
15
- # Get the API key
16
- api_key = os.getenv('OPENAI_API_KEY')
17
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)
 
 
18
 
19
- def sql_agent(file, agent_input):
20
- # Read the uploaded file into a DataFrame
21
- df = pd.read_csv(file)
22
- # Create SQLAlchemy engine
23
- engine = create_engine("sqlite:///uploaded_data.db")
24
- # Write the DataFrame to the SQLite database
25
- df.to_sql("uploaded_data", engine, index=False, if_exists='replace')
26
- db = SQLDatabase(engine=engine)
27
- # Create SQL agent
28
- agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
29
- agent_output = agent_executor.invoke(agent_input)
30
- return agent_output
31
-
32
- def pandas_agent(file, agent_input):
33
- # Check if the file extension is CSV
34
- if file.name.endswith('.csv'):
35
  df = pd.read_csv(file)
36
- # Check if the file extension is XLSX or XLS
37
- elif file.name.endswith('.xlsx') or file.name.endswith('.xls'):
38
- df = pd.read_excel(file)
39
- else:
40
- return "Unsupported file format. Only CSV, XLS, or XLSX files are supported."
41
- # Proceed with your agent code
42
- agent = create_pandas_dataframe_agent(llm, df, agent_type="openai-tools", verbose=True)
43
- agent_output = agent.invoke(agent_input)
44
- return agent_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import getpass
2
  import os
3
  import pandas as pd
4
  import tabulate
 
13
  from langchain_experimental.tools import PythonAstREPLTool
14
  from langchain_experimental.agents import create_pandas_dataframe_agent
15
 
16
+ class AIAgent:
17
+ def __init__(self):
18
+ # Get the API key
19
+ self.api_key = os.getenv('OPENAI_API_KEY')
20
+ self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)
21
 
22
+ def sql_agent(self, file, agent_input):
23
+ # Read the uploaded file into a DataFrame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  df = pd.read_csv(file)
25
+ # Create SQLAlchemy engine
26
+ engine = create_engine("sqlite:///uploaded_data.db")
27
+ # Write the DataFrame to the SQLite database
28
+ df.to_sql("uploaded_data", engine, index=False, if_exists='replace')
29
+ db = SQLDatabase(engine=engine)
30
+ # Create SQL agent
31
+ agent_executor = create_sql_agent(self.llm, db=db, agent_type="openai-tools", verbose=True)
32
+ agent_output = agent_executor.invoke(agent_input)
33
+ return agent_output
34
+
35
+ def pandas_agent(self, file, agent_input):
36
+ # Check if the file extension is CSV
37
+ if file.name.endswith('.csv'):
38
+ df = pd.read_csv(file)
39
+ # Check if the file extension is XLSX or XLS
40
+ elif file.name.endswith('.xlsx') or file.name.endswith('.xls'):
41
+ df = pd.read_excel(file)
42
+ else:
43
+ return "Unsupported file format. Only CSV, XLS, or XLSX files are supported."
44
+ # Proceed with your agent code
45
+ agent = create_pandas_dataframe_agent(self.llm, df, agent_type="openai-tools", verbose=True)
46
+ agent_output = agent.invoke(agent_input)
47
+ return agent_output