vinhnx90 commited on
Commit
d1e97a4
β€’
1 Parent(s): d036875

Update app

Browse files
Files changed (2) hide show
  1. app.py +13 -15
  2. document_retriever.py +11 -5
app.py CHANGED
@@ -47,23 +47,19 @@ with st.sidebar:
47
  with col2:
48
  st.header(":books: InkChatGPT")
49
 
50
- # chat_tab,
51
- documents_tab, settings_tab = st.tabs(
52
- [
53
- # "Chat",
54
- "Documents",
55
- "Settings",
56
- ]
57
- )
58
  with settings_tab:
59
  openai_api_key = st.text_input("OpenAI API Key", type="password")
 
 
 
 
 
 
 
60
  if len(msgs.messages) == 0 or st.button("Clear message history"):
61
  msgs.clear()
62
- msgs.add_ai_message("""
63
- Hi, your uploaded document(s) had been analyzed.
64
-
65
- Feel free to ask me any questions. For example: you can start by asking me `'What is this book about?` or `Tell me about the content of this book!`'
66
- """)
67
 
68
  with documents_tab:
69
  uploaded_files = st.file_uploader(
@@ -74,10 +70,12 @@ with st.sidebar:
74
  )
75
 
76
  if not openai_api_key:
77
- st.info("πŸ”‘ Please Add your **OpenAI API key** on the `Settings` to continue.")
78
 
79
  if uploaded_files:
80
- result_retriever = configure_retriever(uploaded_files)
 
 
81
 
82
  if result_retriever is not None:
83
  memory = ConversationBufferMemory(
 
47
  with col2:
48
  st.header(":books: InkChatGPT")
49
 
50
+ documents_tab, settings_tab = st.tabs(["Documents", "Settings"])
 
 
 
 
 
 
 
51
  with settings_tab:
52
  openai_api_key = st.text_input("OpenAI API Key", type="password")
53
+
54
+ cohere_api_key = ""
55
+ if st.toggle(
56
+ label="Use Cohere's Rerank", help="https://txt.cohere.com/rerank/"
57
+ ):
58
+ cohere_api_key = st.text_input("Cohere API Key", type="password")
59
+
60
  if len(msgs.messages) == 0 or st.button("Clear message history"):
61
  msgs.clear()
62
+ msgs.add_ai_message("Hello, how can I help you?")
 
 
 
 
63
 
64
  with documents_tab:
65
  uploaded_files = st.file_uploader(
 
70
  )
71
 
72
  if not openai_api_key:
73
+ st.info("πŸ”‘ Please open the `Settings` tab from side bar menu to get started.")
74
 
75
  if uploaded_files:
76
+ result_retriever = configure_retriever(
77
+ uploaded_files, cohere_api_key=cohere_api_key
78
+ )
79
 
80
  if result_retriever is not None:
81
  memory = ConversationBufferMemory(
document_retriever.py CHANGED
@@ -3,7 +3,7 @@ import tempfile
3
 
4
  import streamlit as st
5
  from langchain.retrievers import ContextualCompressionRetriever
6
-
7
  from langchain_cohere import CohereRerank
8
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -11,10 +11,11 @@ from langchain_community.vectorstores import DocArrayInMemorySearch
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
 
13
  EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 
14
 
15
 
16
  @st.cache_resource(ttl="1h")
17
- def configure_retriever(files, use_compression=False):
18
  # Read documents
19
  docs = []
20
  temp_dir = tempfile.TemporaryDirectory()
@@ -54,8 +55,13 @@ def configure_retriever(files, use_compression=False):
54
  if not use_compression:
55
  return retriever
56
 
57
- compressor = CohereRerank()
 
 
 
 
 
 
58
  return ContextualCompressionRetriever(
59
- base_compressor=compressor,
60
- base_retriever=retriever,
61
  )
 
3
 
4
  import streamlit as st
5
  from langchain.retrievers import ContextualCompressionRetriever
6
+ from langchain.retrievers.document_compressors import EmbeddingsFilter
7
  from langchain_cohere import CohereRerank
8
  from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
 
13
  EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
14
+ RERANK_MODEL = "rerank-english-v2.0"
15
 
16
 
17
  @st.cache_resource(ttl="1h")
18
+ def configure_retriever(files, cohere_api_key, use_compression=False):
19
  # Read documents
20
  docs = []
21
  temp_dir = tempfile.TemporaryDirectory()
 
55
  if not use_compression:
56
  return retriever
57
 
58
+ if cohere_api_key.len() == 0:
59
+ compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
60
+ else:
61
+ compressor = CohereRerank(
62
+ top_n=3, model=RERANK_MODEL, cohere_api_key=cohere_api_key
63
+ )
64
+
65
  return ContextualCompressionRetriever(
66
+ base_compressor=compressor, base_retriever=retriever
 
67
  )