NEXAS's picture
Update main.py
289a8e7 verified
raw
history blame
2.6 kB
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.agents import Tool
from langchain_experimental.utilities import PythonREPL # type: ignore
from langchain_community.chat_models import ChatOllama
import autopep8 # type: ignore
import pandas as pd
import os
from dotenv import load_dotenv
# Populate os.environ from a local .env file (when one exists), then read the
# Groq API key from the environment; the value is None if the key is unset.
# NOTE(review): this name is not referenced in the visible code — ChatGroq is
# constructed without an explicit key and presumably reads GROQ_API_KEY from
# the environment itself; confirm before relying on this variable.
load_dotenv()
groq_api_key = os.environ.get("GROQ_API_KEY")
class datachat:
    """Answer natural-language questions about a CSV file via LLM-generated pandas code.

    Each call to ``data_ops`` asks the chat model for a short pandas snippet,
    executes it against the current dataframe (bound as both ``df`` and
    ``df_temp``), writes the resulting ``df_temp`` to ``output_path`` and
    returns it. Because ``data_ops`` reloads ``output_path`` when that file
    exists, successive queries build on the previous result.
    """

    def __init__(self, file_path, output_path="./output.csv",
                 model_name="mixtral-8x7b-32768"):
        """Create the chat model and the prompt template text.

        Args:
            file_path: CSV file loaded when no previous output file exists.
            output_path: File each query's result is written to and re-read
                from on the next query. Defaults to ``./output.csv``.
            model_name: Model identifier passed to ``ChatGroq``.
        """
        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        self.llm = ChatGroq(temperature=0, model_name=model_name,
                            callback_manager=callback_manager)
        # Prompt template filled with {columns} and {input}. The example uses
        # `df` while the instruction mentions `df_temp`; both names are bound
        # when the generated code is executed.
        self.instruction = (
            "\n"
            "As a python coder create a pythonic response for the query with "
            "reference to the columns in my pandas dataframe{columns}.\n"
            "Instruction:\n"
            "Do not write the whole script just give me a pythonic response for "
            "this query and do not extend more than asked. Assume a dataframe "
            "variable df_temp.\n"
            "Enclose the generated code in Markdown code embedding format. Do not "
            "generate sample output. Answer the question and provide a one-line "
            "explanation and stop.\n"
            "example:\n"
            "```python\n"
            "output = df['region'].unique()\n"
            "```\n"
            "question: {input}\n"
            "answer:\n"
        )
        self.file_path = file_path
        self.output_path = output_path

    def extract_code(self, response):
        """Return the code inside the first Markdown code fence of ``response``.

        A fence tagged ``python`` (any case) is preferred; if none exists, the
        first bare ```` ``` ```` fence is used. Lines up to the next line that
        is exactly ```` ``` ```` (ignoring surrounding whitespace) are returned,
        joined with newlines. Returns ``""`` when no usable fence is found.
        """
        lines = response.splitlines()
        start = next(
            (i for i, line in enumerate(lines) if "```python" in line.lower()),
            None,
        )
        if start is None:
            # Models sometimes omit the language tag on the fence.
            start = next(
                (i for i, line in enumerate(lines) if line.strip() == "```"),
                None,
            )
        if start is None:
            return ""
        body = []
        for line in lines[start + 1:]:
            if line.strip() == "```":
                break
            body.append(line)
        return "\n".join(body)

    def _run_generated_code(self, code, df):
        """Execute ``code`` with ``df``/``df_temp``/``pd`` bound; return the final ``df_temp``.

        An explicit namespace is required because ``exec`` inside a function
        cannot rebind that function's local variables; reading ``df_temp``
        back from the namespace captures reassignments made by the code.
        A resulting Series is converted to a one-column DataFrame.

        Raises:
            TypeError: if ``df_temp`` ends up as something other than a
                DataFrame or Series.
        """
        # SECURITY: `code` is raw language-model output executed with full
        # interpreter privileges. Only run this against trusted data, ideally
        # inside a sandbox.
        namespace = {"pd": pd, "df": df, "df_temp": df}
        exec(code, namespace)
        result = namespace.get("df_temp")
        if isinstance(result, pd.Series):
            result = result.to_frame()
        if not isinstance(result, pd.DataFrame):
            raise TypeError(
                "generated code left df_temp as "
                f"{type(result).__name__}, expected a DataFrame"
            )
        return result

    def data_ops(self, query):
        """Run one natural-language ``query`` against the data and persist the result.

        Loads ``output_path`` if a previous query produced it, otherwise
        ``file_path``; asks the model for pandas code that references the
        dataframe's columns; executes that code; writes the resulting
        ``df_temp`` to ``output_path`` without the index and returns it.
        """
        # Chain queries: start from the last saved result when there is one.
        if os.path.isfile(self.output_path):
            source = self.output_path
        else:
            source = self.file_path
        df = pd.read_csv(source)
        prompt = PromptTemplate.from_template(self.instruction)
        chain = LLMChain(llm=self.llm, prompt=prompt)
        response = chain.invoke(
            input={"columns": df.columns.tolist(), "input": query}
        )
        code = autopep8.fix_code(self.extract_code(response["text"]))
        df_temp = self._run_generated_code(code, df)
        df_temp.to_csv(self.output_path, index=False)
        return df_temp