adityaagrawal commited on
Commit
7c0f544
1 Parent(s): 536d66f

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +200 -0
utils.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from langchain_openai import OpenAI
2
+ from langchain_experimental.agents import create_pandas_dataframe_agent
3
+ import pandas as pd
4
+ from dotenv import load_dotenv
5
+ import os
6
+ import streamlit as st
7
+ import json
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.text_splitter import CharacterTextSplitter
10
+ from langchain_openai import OpenAIEmbeddings
11
+ import weaviate
12
+ from langchain_community.vectorstores import Weaviate
13
+ from weaviate.embedded import EmbeddedOptions
14
+ from langchain.prompts import ChatPromptTemplate
15
+ from langchain.document_loaders.pdf import PyPDFLoader
16
+ from langchain.schema.runnable import RunnablePassthrough
17
+ from langchain.schema.output_parser import StrOutputParser
18
+ import gradio as gr
19
+ # from langchain_community.llms import ctransformers
20
+ # from ctransformers import AutoModelForCausalLM
21
+
22
+ load_dotenv()
23
+ API_KEY=os.getenv("OPENAI_API_KEY")
24
+ TEMP_DIR = "../temp"
25
+
26
+ # pdf
27
+ def agent(filename: str):
28
+
29
+ llm = ChatOpenAI(
30
+ model = "gpt-3.5-turbo-0125",
31
+ # model = "gpt-4",
32
+ temperature = 0.0,
33
+ # max_tokens = 256,
34
+ # top_p = 0.5,
35
+ )
36
+ df = pd.read_csv(filename, encoding='unicode_escape')
37
+ pandas_df_agent = create_pandas_dataframe_agent(llm, df, verbose=True)
38
+
39
+ return pandas_df_agent
40
+
41
+ def get_response(agent, query):
42
+ prompt = (
43
+ """
44
+ For the following query, if it requires drawing a table, reply as follows:
45
+ {"table": {"columns": ["column1", "column2", ...], "data": [[value1, value2, ...], [value1, value2, ...], ...]}}
46
+
47
+ If the query requires creating a bar chart, reply as follows:
48
+ {"bar": {"columns": ["A", "B", "C", ...], "data": [25, 24, 10, ...]}}
49
+
50
+ If the query requires creating a line chart, reply as follows:
51
+ {"line": {"columns": ["A", "B", "C", ...], "data": [25, 24, 10, ...]}}
52
+
53
+ There can only be two types of charts, "bar" and "line".
54
+
55
+ If it is just asking a question that requires neither, reply as follows:
56
+ {"answer": "answer"}
57
+ Example:
58
+ {"answer": "The product with the highest sales is 'Classic Cars.'"}
59
+
60
+ Write supportive numbers if there are any in the answer.
61
+ Example:
62
+ {"answer": "The product with the highest sales is 'Classic Cars' with 1111 sales."}
63
+
64
+ If you do not know the answer, reply as follows:
65
+ {"answer": "I do not know."}
66
+
67
+ Do not hallucinate or make up data. If the data is not available, reply "I do not know."
68
+
69
+ Return all output as a string in double quotes.
70
+
71
+ All strings in "columns" list and data list, should be in double quotes,
72
+
73
+ For example: {"columns": ["title", "ratings_count"], "data": [["Gilead", 361], ["Spider's Web", 5164]]}
74
+
75
+ Lets think step by step.
76
+
77
+ Below is the query.
78
+ Query:
79
+ """
80
+ + query
81
+ )
82
+
83
+ response = agent.run(prompt)
84
+ return response.__str__()
85
+
86
+ def return_response(response: str) -> dict:
87
+ try:
88
+ return json.loads(response)
89
+ except json.JSONDecodeError as e:
90
+ print(f"JSONDecodeError: {e}")
91
+ return None
92
+
93
+ def write_response(response_dict: dict):
94
+ if response_dict is not None:
95
+ if "answer" in response_dict:
96
+ answer = response_dict["answer"]
97
+ # st.write(answer)
98
+ return answer
99
+
100
+ if "bar" in response_dict:
101
+ data = response_dict["bar"]
102
+ df = pd.DataFrame.from_dict(data, orient = 'index')
103
+ df = df.transpose()
104
+ df.set_index("columns", inplace=True)
105
+ # st.bar_chart(df)
106
+ return gr.BarPlot(df)
107
+
108
+ if "line" in response_dict:
109
+ data = response_dict["line"]
110
+ df = pd.DataFrame(data)
111
+ df.set_index("columns", inplace=True)
112
+ # st.line_chart(df)
113
+ return gr.LinePlot(df)
114
+
115
+ # if "table" in response_dict:
116
+ # data = response_dict["table"]
117
+ # df = pd.DataFrame(data["data"], columns=data["columns"])
118
+ # # st.table(df)
119
+
120
+
121
+ else:
122
+ answer = "Decoded response is None. Please retry with a better prompt."
123
+ return (answer)
124
+
125
+ def ques_csv(data, question: str):
126
+ csv_agent = agent(data)
127
+ response = get_response(agent = csv_agent, query = question)
128
+ decoded_response = return_response(response)
129
+ answer = write_response(decoded_response)
130
+ return answer
131
+
132
+ # pdf
133
+ def ques_pdf(data, question: str):
134
+ doc = load_pdf(data)
135
+ chunks = split_pdf(doc)
136
+ retriever = store_retrieve(chunks)
137
+ prompt = write_prompt()
138
+ answer = ques_llm(retriever, prompt, question)
139
+ # st.write(answer)
140
+ return answer
141
+
142
+ def make_dir():
143
+ if not os.path.exists(TEMP_DIR):
144
+ os.makedirs(TEMP_DIR)
145
+
146
+ def upload(uploaded_file):
147
+ if uploaded_file is not None:
148
+ file_path = os.path.join(TEMP_DIR, uploaded_file.name)
149
+ with open(file_path, "wb") as f:
150
+ f.write(uploaded_file.getvalue())
151
+
152
+ return file_path
153
+
154
+ def load_pdf(filename: str):
155
+ loader = PyPDFLoader("{}".format(filename))
156
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
157
+ pages = loader.load_and_split(text_splitter = text_splitter)
158
+ return pages
159
+
160
+ def split_pdf(doc):
161
+ text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
162
+ chunks = text_splitter.split_documents(doc)
163
+ return chunks
164
+
165
+ def store_retrieve(chunks):
166
+ client = weaviate.Client(
167
+ embedded_options = EmbeddedOptions()
168
+ )
169
+ vectorstore = Weaviate.from_documents(
170
+ client = client,
171
+ documents = chunks,
172
+ embedding = OpenAIEmbeddings(),
173
+ by_text = False
174
+ )
175
+ retriever = vectorstore.as_retriever()
176
+ return retriever
177
+
178
+ def write_prompt():
179
+ template = """You are an assistant for question-answering tasks.
180
+ Use the following pieces of retrieved context to answer the question.
181
+ If you don't know the answer, just say that you don't know.
182
+ Question: {question}
183
+ Context: {context}
184
+ Answer:
185
+ """
186
+ prompt = ChatPromptTemplate.from_template(template)
187
+ return prompt
188
+
189
+ def ques_llm(retriever, prompt, question):
190
+ llm = ChatOpenAI(model_name="gpt-4", temperature=0)
191
+ # # llm = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGML", model_file="llama-2-7b-chat.ggmlv3.q8_0.bin", temperature=0)
192
+ # llm = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGML", model_file="llama-2-7b-chat.ggmlv3.q4_0.bin", temperature=0)
193
+ rag_chain = (
194
+ {"context": retriever, "question": RunnablePassthrough()}
195
+ | prompt
196
+ | llm
197
+ | StrOutputParser()
198
+ )
199
+ ans = rag_chain.invoke(question)
200
+ return ans