ababio commited on
Commit
5db7189
1 Parent(s): 9f011a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -35
app.py CHANGED
@@ -1,9 +1,6 @@
1
- # Streamlit application
2
- import streamlit as st
3
  import os
4
  from getpass import getpass
5
- from transformers import pipeline
6
-
7
  from llama_index.node_parser import SemanticSplitterNodeParser
8
  from llama_index.embeddings import OpenAIEmbedding
9
  from llama_index.ingestion import IngestionPipeline
@@ -14,12 +11,17 @@ from llama_index import VectorStoreIndex
14
  from llama_index.retrievers import VectorIndexRetriever
15
  from llama_index.query_engine import RetrieverQueryEngine
16
 
17
- # Function to initialize the Pinecone and LlamaIndex setup
18
- def initialize_pipeline():
19
- pinecone_api_key = os.getenv("PINECONE_API_KEY")
20
- openai_api_key = os.getenv("OPENAI_API_KEY")
 
21
 
 
 
22
  embed_model = OpenAIEmbedding(api_key=openai_api_key)
 
 
23
  pipeline = IngestionPipeline(
24
  transformations=[
25
  SemanticSplitterNodeParser(
@@ -31,41 +33,32 @@ def initialize_pipeline():
31
  ],
32
  )
33
 
 
34
  pc = PineconeGRPC(api_key=pinecone_api_key)
35
  index_name = "anualreport"
 
 
36
  pinecone_index = pc.Index(index_name)
 
 
37
  vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
38
- pinecone_index.describe_index_stats()
39
 
40
- if not os.getenv('OPENAI_API_KEY'):
41
- os.environ['OPENAI_API_KEY'] = openai_api_key
42
 
 
43
  vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
44
- retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
45
- query_engine = RetrieverQueryEngine(retriever=retriever)
46
-
47
- return query_engine
48
-
49
- # Streamlit UI
50
- st.title("Chat with Annual Reports")
51
 
52
- # Initialize the query engine
53
- query_engine = initialize_pipeline()
54
-
55
- # Conversation model using Hugging Face transformers
56
- conversation_pipeline = pipeline("conversational", model="microsoft/DialoGPT-medium")
57
-
58
- # User input
59
- user_input = st.text_input("You: ", "")
60
 
61
- if user_input:
62
- # Query the vector DB
63
- llm_query = query_engine.query(user_input)
64
- response = llm_query.response
65
 
66
- # Generate response using Hugging Face conversation model
67
- conversation = conversation_pipeline([user_input, response])
68
- bot_response = conversation[-1]["generated_text"]
 
69
 
70
- # Display response
71
- st.text_area("Bot: ", bot_response, height=200)
 
 
 
 
1
  import os
2
  from getpass import getpass
3
+ import streamlit as st
 
4
  from llama_index.node_parser import SemanticSplitterNodeParser
5
  from llama_index.embeddings import OpenAIEmbedding
6
  from llama_index.ingestion import IngestionPipeline
 
11
  from llama_index.retrievers import VectorIndexRetriever
12
  from llama_index.query_engine import RetrieverQueryEngine
13
 
14
+ # Read the Pinecone and OpenAI API keys from environment variables.
+ # NOTE(review): despite the original comment, the keys are NOT entered
+ # through the Streamlit UI -- and the imported `getpass` is unused here.
15
+ st.title("Annual Report Summary Chatbot")
16
+ pinecone_api_key = os.getenv("PINECONE_API_KEY")
17
+ openai_api_key = os.getenv("OPENAI_API_KEY")
18
+
19

20
+ # Proceed only when both keys are configured; otherwise the app renders
+ # just the title with no feedback about the missing configuration.
+ if pinecone_api_key and openai_api_key:
21
+ # Embedding model used both for semantic node parsing and for vectorization
22
  embed_model = OpenAIEmbedding(api_key=openai_api_key)
23
+
24
+ # Define the initial pipeline
25
  pipeline = IngestionPipeline(
26
  transformations=[
27
  SemanticSplitterNodeParser(
 
33
  ],
34
  )
35
 
36
+ # Initialize the connection to Pinecone via the gRPC client
37
  pc = PineconeGRPC(api_key=pinecone_api_key)
38
  index_name = "anualreport"
39
+
40
+ # Open a handle to the existing index (the name "anualreport" is spelled
+ # as-is; it must match the index name in the Pinecone project)
41
  pinecone_index = pc.Index(index_name)
42
+
43
+ # Expose the Pinecone index to LlamaIndex as a vector store
44
  vector_store = PineconeVectorStore(pinecone_index=pinecone_index)

45

46
+ pinecone_index.describe_index_stats()  # return value discarded -- presumably a connectivity sanity check; TODO confirm
 
47
 
48
+ # Build a VectorStoreIndex over the already-populated Pinecone store
+ # (no ingestion happens here; the index is assumed to be pre-loaded)
49
  vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)







50

51
+ # Retriever returns the 5 most similar nodes for each query
52
+ retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)






53

54
+ # Query engine synthesizes an answer from the retrieved nodes
55
+ query_engine = RetrieverQueryEngine(retriever=retriever)
 
 
56
 
57
# Free-text question box; the query runs only on an explicit Submit click.
user_query = st.text_input("Ask a question about the annual report:")
# Guard against empty/whitespace-only input: the original sent "" to the
# query engine (a wasted LLM/vector-store round trip) when Submit was
# clicked with nothing typed.
if st.button("Submit") and user_query.strip():
    # Retrieve the top-k nodes, synthesize an answer, and render it.
    llm_query = query_engine.query(user_query)
    st.write(llm_query.response)
61
 
62
+ # Run the app
63
+ if __name__ == "__main__":
64
+ st.run()