Akshayram1 committed
Commit a422b23
1 Parent(s): 9a90d0f

Update app.py

Files changed (1)
app.py +74 -41
app.py CHANGED
@@ -9,67 +9,100 @@ from langchain.chains import RetrievalQA
 from langchain_groq import ChatGroq
 import joblib
 import os
+import tempfile

 # Streamlit specific setup
 st.title("Uber Quarterly Report QA")
 user_llama_key = st.text_input("Enter your LLAMA Cloud API key:")
 user_groq_key = st.text_input("Enter your GROQ API key:")

-# LLAMAPARSE
-parser = LlamaParse(api_key=user_llama_key, result_type="markdown", verbose=True)
-
-# Function to load or parse data
-def load_or_parse_data():
+# Function to load or parse data from the uploaded PDF file
+def load_or_parse_data(uploaded_file):
     data_file = "./data/parsed_data.pkl"
-    if os.path.exists(data_file):
-        parsed_data = joblib.load(data_file)
-    else:
-        parsing_instruction = """The provided document is a quarterly report filed by Uber Technologies,
-        Inc. with the Securities and Exchange Commission (SEC)...
-        """
-        parser = LlamaParse(api_key=user_llama_key, result_type="markdown", parsing_instruction=parsing_instruction, max_timeout=5000)
-        llama_parse_documents = parser.load_data("./uber_10q_march_2022.pdf")
-        joblib.dump(llama_parse_documents, data_file)
-        parsed_data = llama_parse_documents
-    return parsed_data
-
-# Create vector database
-def create_vector_database():
-    llama_parse_documents = load_or_parse_data()
+    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+        temp_file.write(uploaded_file.getvalue())
+        temp_file_path = temp_file.name
+    parsing_instruction = """The provided document is a quarterly report filed by Uber Technologies,
+    Inc. with the Securities and Exchange Commission (SEC)...
+    """
+    parser = LlamaParse(api_key=user_llama_key, result_type="markdown", parsing_instruction=parsing_instruction, max_timeout=5000)
+    llama_parse_documents = parser.load_data(temp_file_path)
+    os.remove(temp_file_path)
+    return llama_parse_documents
+
+# User uploads a PDF file
+uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
+if uploaded_file is not None:
+    llama_parse_documents = load_or_parse_data(uploaded_file)
+
+    # Further processing of the parsed data
     with open('data/output.md', 'a') as f:
         for doc in llama_parse_documents:
             f.write(doc.text + '\n')
+
     markdown_path = "data/output.md"
     loader = UnstructuredMarkdownLoader(markdown_path)
     documents = loader.load()
+
+    # Split the loaded documents into chunks
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
     docs = text_splitter.split_documents(documents)
+
+    # Initialize embeddings
     embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
+
+    # Create and persist a Chroma vector database from the chunked documents
     vs = Chroma.from_documents(
         documents=docs,
         embedding=embed_model,
         persist_directory="chroma_db_llamaparse1",
         collection_name="rag"
     )
-    return vs, embed_model
-
-# Create QA model
-vs, embed_model = create_vector_database()
-chat_model = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768", api_key=user_groq_key)
-vectorstore = Chroma(embedding_function=embed_model, persist_directory="chroma_db_llamaparse1", collection_name="rag")
-retriever = vectorstore.as_retriever(search_kwargs={'k': 3})
-
-# Define custom prompt template
-custom_prompt_template = """Use the following pieces of information to answer the user's question...
-"""
-def set_custom_prompt():
+
+    # Initialize the ChatGroq model
+    chat_model = ChatGroq(
+        temperature=0,
+        model_name="mixtral-8x7b-32768",
+        api_key=user_groq_key
+    )
+
+    # Define the custom prompt template
+    custom_prompt_template = """
+    Use the following pieces of information to answer the user's question.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+    Context: {context}
+    Question: {question}
+
+    Only return the helpful answer below and nothing else.
+    Helpful answer:
+    """
     prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
-    return prompt
-
-# Get user query and retrieve answer
-user_query = st.text_input("Ask a question:")
-if st.button("Get Answer"):
-    prompt = set_custom_prompt()
-    qa = RetrievalQA.from_chain_type(llm=chat_model, chain_type="stuff", retriever=retriever, return_source_documents=True, chain_type_kwargs={"prompt": prompt})
-    response = qa.invoke({"query": user_query})
-    st.write(response["result"])
+
+    # Initialize RetrievalQA
+    qa = RetrievalQA.from_chain_type(
+        llm=chat_model,
+        chain_type="stuff",
+        retriever=vs.as_retriever(search_kwargs={'k': 3}),
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt}
+    )
+
+    # Helper to run a question through the QA chain
+    def ask_question(question):
+        response = qa.invoke({"query": question})
+        return response["result"]
+
+    # Example questions
+    example_questions = [
+        "What is the Balance of UBER TECHNOLOGIES, INC. as of December 31, 2021?",
+        "What is the Cash flows from operating activities associated with bad expense specified in the document?",
+        "What is Loss (income) from equity method investments, net?"
+    ]
+
+    # Ask the example questions and display the answers
+    for idx, question in enumerate(example_questions, start=1):
+        st.subheader(f"Question {idx}: {question}")
+        answer = ask_question(question)
+        st.write(f"Answer: {answer}")