JMuscatello committed on
Commit
38e58b7
•
1 Parent(s): 647ce10

Add ability to upload docs

Browse files
Files changed (1) hide show
  1. pages/6_🔎_Find_Demo.py +104 -41
pages/6_🔎_Find_Demo.py CHANGED
@@ -1,9 +1,14 @@
1
  import os
 
 
2
 
3
  import pandas as pd
4
 
5
  import streamlit as st
6
  import streamlit_analytics
 
 
 
7
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
8
 
9
  from huggingface_hub import snapshot_download
@@ -14,9 +19,9 @@ from haystack.nodes import BM25Retriever, EmbeddingRetriever
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
  DATA_REPO_ID = "simplexico/cuad-qa-answers"
16
  DATA_FILENAME = "cuad_questions_answers.json"
17
- EMBEDDING_MODEL = "prajjwal1/bert-tiny"
18
- if EMBEDDING_MODEL == "prajjwal1/bert-tiny":
19
- EMBEDDING_DIM = 128
20
  else:
21
  EMBEDDING_DIM = 768
22
 
@@ -31,7 +36,7 @@ def load_dataset():
31
  return df
32
 
33
  @st.cache(allow_output_mutation=True)
34
- def generate_document_store(df, dummy=None):
35
  """Create haystack document store using contract clause data
36
  """
37
  document_dicts = []
@@ -39,24 +44,44 @@ def generate_document_store(df, dummy=None):
39
  for idx, row in df.iterrows():
40
  document_dicts.append(
41
  {
42
- 'content': row['answer_text'],
43
- 'meta': {'contract_title': row['contract_title'], 'question_id': row['question_id']}
44
  }
45
  )
46
 
47
- document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=EMBEDDING_DIM)
48
 
49
  document_store.write_documents(document_dicts)
50
 
51
  return document_store
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @st.cache(allow_output_mutation=True)
54
  def generate_bm25_retriever(document_store):
55
  return BM25Retriever(document_store)
56
 
57
  @st.cache(allow_output_mutation=True)
58
  def generate_embeddings(embedding_model, document_store):
59
- embedding_retriever = EmbeddingRetriever(embedding_model=embedding_model, document_store=document_store)
 
 
 
 
 
60
  document_store.update_embeddings(embedding_retriever)
61
  return embedding_retriever
62
 
@@ -64,16 +89,27 @@ def process_query(query, retriever):
64
  """Generates dataframe with top ten results"""
65
  texts = []
66
  contract_titles = []
 
 
67
  candidate_documents = retriever.retrieve(
68
  query=query,
69
  top_k=10,
70
  )
71
 
72
- for document in candidate_documents:
73
  texts.append(document.content)
74
  contract_titles.append(document.meta["contract_title"])
 
 
75
 
76
- return pd.DataFrame({"Text": texts, "Source Contract": contract_titles})
 
 
 
 
 
 
 
77
 
78
  st.set_page_config(
79
  page_title="Find Demo",
@@ -93,50 +129,77 @@ st.sidebar.success("👆 Select a demo above.")
93
  st.title('🔎 Find Demo')
94
 
95
  st.write("""
96
- This demo shows how a set of documents can be searched.
97
- We've set up a database of clauses from a set of open source legal documents.
98
- These clauses can be searched using **keywords** or using **semantic search**.
99
  Semantic search leverages an AI model which matches on clauses with a similar meaning to the input text.
100
  """)
101
- st.write("**👈 Enter search query on the left** and hit the button **Find Clauses** to see the demo in action")
102
 
103
- query = st.sidebar.text_area(label='Enter Searcb Query', value=EXAMPLE_TEXT, height=250)
104
- button = st.sidebar.button('**Find Clauses**', type='primary', use_container_width=True)
105
 
106
- df = load_dataset()
107
 
108
- document_store = generate_document_store(df)
109
 
110
- bm25_retriever = generate_bm25_retriever(document_store)
111
 
112
- embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
 
113
 
114
- if button:
115
-
116
- hide_dataframe_row_index = """
117
- <style>
118
- .row_heading.level0 {display:none}
119
- .blank {display:none}
120
- </style>
121
- """
122
-
123
- col1, col2 = st.columns(2)
124
 
 
125
  with col1:
126
-
127
- st.subheader('Keyword Search Results:')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # Inject CSS with Markdown
129
  st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
130
- df_bm25 = process_query(query, bm25_retriever)
131
- st.table(df_bm25)
132
-
133
- with col2:
134
 
135
- st.subheader('Semantic Search Results:')
136
- # Inject CSS with Markdown
137
- st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
138
- df_embed = process_query(query, embedding_retriever)
139
- st.table(df_embed)
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  add_email_signup_form()
142
 
 
1
  import os
2
+ from io import StringIO
3
+ import re
4
 
5
  import pandas as pd
6
 
7
  import streamlit as st
8
  import streamlit_analytics
9
+
10
+ import streamlit_toggle as tog
11
+
12
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
13
 
14
  from huggingface_hub import snapshot_download
 
19
  HF_TOKEN = os.environ.get("HF_TOKEN")
20
  DATA_REPO_ID = "simplexico/cuad-qa-answers"
21
  DATA_FILENAME = "cuad_questions_answers.json"
22
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
23
+ if EMBEDDING_MODEL == "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" or EMBEDDING_MODEL == "sentence-transformers/paraphrase-MiniLM-L3-v2":
24
+ EMBEDDING_DIM = 384
25
  else:
26
  EMBEDDING_DIM = 768
27
 
 
36
  return df
37
 
38
  @st.cache(allow_output_mutation=True)
39
+ def generate_document_store(df):
40
  """Create haystack document store using contract clause data
41
  """
42
  document_dicts = []
 
44
  for idx, row in df.iterrows():
45
  document_dicts.append(
46
  {
47
+ 'content': row['paragraph'],
48
+ 'meta': {'contract_title': row['contract_title']}
49
  }
50
  )
51
 
52
+ document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=EMBEDDING_DIM, similarity='cosine')
53
 
54
  document_store.write_documents(document_dicts)
55
 
56
  return document_store
57
 
58
+ def files_to_dataframe(uploaded_files, limit=10):
59
+ texts = []
60
+ titles = []
61
+ for uploaded_file in uploaded_files[:limit]:
62
+
63
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
64
+
65
+ text = stringio.read().strip()
66
+ paragraphs = text.split("\n\n")
67
+ paragraphs = [p.strip() for p in paragraphs if len(p.split()) > 10]
68
+ texts.extend(paragraphs)
69
+ titles.extend([uploaded_file.name]*len(paragraphs))
70
+
71
+ return pd.DataFrame({'paragraph': texts, 'contract_title': titles})
72
+
73
  @st.cache(allow_output_mutation=True)
74
  def generate_bm25_retriever(document_store):
75
  return BM25Retriever(document_store)
76
 
77
  @st.cache(allow_output_mutation=True)
78
  def generate_embeddings(embedding_model, document_store):
79
+ embedding_retriever = EmbeddingRetriever(
80
+ embedding_model=embedding_model,
81
+ document_store=document_store,
82
+ model_format="sentence_transformers",
83
+ scale_score=True
84
+ )
85
  document_store.update_embeddings(embedding_retriever)
86
  return embedding_retriever
87
 
 
89
  """Generates dataframe with top ten results"""
90
  texts = []
91
  contract_titles = []
92
+ scores = []
93
+ ranking = []
94
  candidate_documents = retriever.retrieve(
95
  query=query,
96
  top_k=10,
97
  )
98
 
99
+ for idx, document in enumerate(candidate_documents):
100
  texts.append(document.content)
101
  contract_titles.append(document.meta["contract_title"])
102
+ scores.append(str(round(document.score, 2)))
103
+ ranking.append(idx + 1)
104
 
105
+ return pd.DataFrame(
106
+ {
107
+ "Ranking": ranking,
108
+ "Text": texts,
109
+ "Source Contract": contract_titles,
110
+ "Similarity": scores
111
+ }
112
+ )
113
 
114
  st.set_page_config(
115
  page_title="Find Demo",
 
129
  st.title('🔎 Find Demo')
130
 
131
  st.write("""
132
+ This demo shows how a set of clauses can be searched.
133
+ Upload a set of contracts on the left and the paragraphs can be searched using **keywords** or using **semantic search**.
 
134
  Semantic search leverages an AI model which matches on clauses with a similar meaning to the input text.
135
  """)
136
+ st.write("**👈 Upload a set of contracts on the left** to start the demo")
137
 
 
 
138
 
139
+ #df = load_dataset()
140
 
141
+ #document_store = generate_document_store(df)
142
 
143
+ #bm25_retriever = generate_bm25_retriever(document_store)
144
 
145
+ #embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
146
+ col1, col2, col3, col4, col5 = st.columns(5)
147
 
148
+ uploaded_files = st.sidebar.file_uploader("Select contracts to search **(upload up to 10 files)**", accept_multiple_files=True)
 
 
 
 
 
 
 
 
 
149
 
150
+ if uploaded_files:
151
  with col1:
152
+ st.write("Toggle between **keyword** or **semantic** search:")
153
+ value = tog.st_toggle_switch(
154
+ label="Keyword/Semantic",
155
+ label_after=True,
156
+ inactive_color='#D3D3D3',
157
+ active_color="#11567f",
158
+ track_color="#29B5E8"
159
+ )
160
+ if value:
161
+ search_type = "semantic"
162
+ else:
163
+ search_type = "keyword"
164
+
165
+ print(value)
166
+
167
+ df = files_to_dataframe(uploaded_files)
168
+ document_store = generate_document_store(df)
169
+ bm25_retriever = generate_bm25_retriever(document_store)
170
+ st.write("**👇 Enter search query below** and hit the button **Find Clauses** to see the demo in action")
171
+ query = st.text_area(label='Enter Search Query', value=EXAMPLE_TEXT, height=250)
172
+ button = st.button('**Find Clauses**', type='primary', use_container_width=True)
173
+
174
+ if button:
175
+
176
+ hide_dataframe_row_index = """
177
+ <style>
178
+ .row_heading.level0 {display:none}
179
+ .blank {display:none}
180
+ </style>
181
+ """
182
+
183
+ st.subheader(f'Search Results ({search_type}):')
184
  # Inject CSS with Markdown
185
  st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
 
 
 
 
186
 
187
+ if search_type == "keyword":
188
+ df_bm25 = process_query(query, bm25_retriever)
189
+ st.table(df_bm25)
190
+
191
+ if search_type == "semantic":
192
+ embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
193
+ df_embed = process_query(query, embedding_retriever)
194
+ st.table(df_embed)
195
+
196
+ # with col2:
197
+
198
+ # st.subheader('Semantic Search Results:')
199
+ # # Inject CSS with Markdown
200
+ # st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
201
+ # df_embed = process_query(query, embedding_retriever)
202
+ # st.table(df_embed)
203
 
204
  add_email_signup_form()
205