Jan Mühlnikel commited on
Commit
71de22d
1 Parent(s): e3302f1

added semantic search function

Browse files
__pycache__/similarity_page.cpython-310.pyc CHANGED
Binary files a/__pycache__/similarity_page.cpython-310.pyc and b/__pycache__/similarity_page.cpython-310.pyc differ
 
functions/__pycache__/calc_matches.cpython-310.pyc CHANGED
Binary files a/functions/__pycache__/calc_matches.cpython-310.pyc and b/functions/__pycache__/calc_matches.cpython-310.pyc differ
 
functions/__pycache__/filter_projects.cpython-310.pyc CHANGED
Binary files a/functions/__pycache__/filter_projects.cpython-310.pyc and b/functions/__pycache__/filter_projects.cpython-310.pyc differ
 
functions/__pycache__/semantic_search.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
functions/calc_matches.py CHANGED
@@ -1,7 +1,7 @@
1
  import pandas as pd
2
  import numpy as np
3
 
4
- def calc_matches(filtered_df, project_df, similarity_matrix):
5
  # matching project2 can be nay project
6
  # indecies (rows) = project1
7
  # columns = project2
@@ -14,7 +14,7 @@ def calc_matches(filtered_df, project_df, similarity_matrix):
14
  match_matrix = similarity_matrix[filtered_df_indecies_list]
15
 
16
  # get row (project1) and column (project2) with highest similarity in filtered df
17
- top_indices = np.unravel_index(np.argsort(match_matrix, axis=None)[-30:], match_matrix.shape)
18
 
19
  # get the corresponding similarity values
20
  top_values = match_matrix[top_indices]
 
1
  import pandas as pd
2
  import numpy as np
3
 
4
+ def calc_matches(filtered_df, project_df, similarity_matrix, top_x):
5
  # matching project2 can be nay project
6
  # indecies (rows) = project1
7
  # columns = project2
 
14
  match_matrix = similarity_matrix[filtered_df_indecies_list]
15
 
16
  # get row (project1) and column (project2) with highest similarity in filtered df
17
+ top_indices = np.unravel_index(np.argsort(match_matrix, axis=None)[-top_x:], match_matrix.shape)
18
 
19
  # get the corresponding similarity values
20
  top_values = match_matrix[top_indices]
functions/filter_projects.py CHANGED
@@ -1,12 +1,13 @@
1
  import pandas as pd
 
2
 
3
  def contains_code(crs_codes, code_list):
4
  codes = str(crs_codes).split(';')
5
  return any(code in code_list for code in codes)
6
 
7
- def filter_projects(df, crs3_list, crs5_list, sdg_str, country_code_list, orga_code_list):
8
  # Check if filters where not all should be selected are empty
9
- if crs3_list != [] or crs5_list != [] or sdg_str != "":
10
 
11
  # FILTER CRS
12
  if crs3_list and not crs5_list:
@@ -33,6 +34,13 @@ def filter_projects(df, crs3_list, crs5_list, sdg_str, country_code_list, orga_c
33
  if orga_code_list != []:
34
  df = df[df['orga_abbreviation'].isin(orga_code_list)]
35
 
 
 
 
 
 
 
 
36
 
37
  return df
38
 
 
1
  import pandas as pd
2
+ from functions.semantic_search import search
3
 
4
  def contains_code(crs_codes, code_list):
5
  codes = str(crs_codes).split(';')
6
  return any(code in code_list for code in codes)
7
 
8
+ def filter_projects(df, crs3_list, crs5_list, sdg_str, country_code_list, orga_code_list, query, model, embeddings, TOP_X_PROJECTS=30):
9
  # Check if filters where not all should be selected are empty
10
+ if crs3_list != [] or crs5_list != [] or sdg_str != "" or query != "":
11
 
12
  # FILTER CRS
13
  if crs3_list and not crs5_list:
 
34
  if orga_code_list != []:
35
  df = df[df['orga_abbreviation'].isin(orga_code_list)]
36
 
37
+ # FILTER QUERY
38
+ if query != "" and len(df) > 0:
39
+ if len(df) < TOP_X_PROJECTS:
40
+ TOP_X_PROJECTS = len(df)
41
+ df = search(query, model, embeddings, df, TOP_X_PROJECTS)
42
+
43
+
44
 
45
  return df
46
 
{modules → functions}/semantic_search.py RENAMED
@@ -2,22 +2,26 @@ import pickle
2
  import faiss
3
  import streamlit as st
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
- def show_search(model, faiss_index, sentences):
7
- query = st.text_input("Enter your search query:")
 
 
 
 
 
 
 
8
 
9
- if query:
10
  # Convert query to embedding
11
  query_embedding = model.encode([query])[0].reshape(1, -1)
12
 
13
  # Perform search
14
- D, I = faiss_index.search(query_embedding, k=5) # Search for top 5 similar items
15
 
16
  # Extract the sentences corresponding to the top indices
17
- top_sentences = [sentences[i] for i in I[0]]
18
-
19
- # Display results as a selection list
20
- selected_sentence = st.selectbox("Top results:", top_sentences)
21
-
22
- # Optionally, do something with the selected sentence
23
- st.write("You selected:", selected_sentence)
 
2
  import faiss
3
  import streamlit as st
4
  from sentence_transformers import SentenceTransformer
5
+ import pandas as pd
6
 
7
+ def search(query, model, embeddings, filtered_df, top_x=30):
8
+
9
+ filtered_df_indecies_list = filtered_df.index
10
+ filtered_embeddings = embeddings[filtered_df_indecies_list]
11
+
12
+ # Load or create FAISS index
13
+ dimension = filtered_embeddings.shape[1]
14
+ faiss_index = faiss.IndexFlatL2(dimension)
15
+ faiss_index.add(filtered_embeddings)
16
 
 
17
  # Convert query to embedding
18
  query_embedding = model.encode([query])[0].reshape(1, -1)
19
 
20
  # Perform search
21
+ D, I = faiss_index.search(query_embedding, k=top_x) # Search for top x similar items
22
 
23
  # Extract the sentences corresponding to the top indices
24
+ #print(filtered_df.columns())
25
+ top_indecies = [i for i in I[0]]
26
+
27
+ return filtered_df.iloc[top_indecies]
 
 
 
modules/__pycache__/semantic_search.cpython-310.pyc CHANGED
Binary files a/modules/__pycache__/semantic_search.cpython-310.pyc and b/modules/__pycache__/semantic_search.cpython-310.pyc differ
 
similarity_page.py CHANGED
@@ -12,7 +12,6 @@ import pickle
12
  import faiss
13
  from sentence_transformers import SentenceTransformer
14
  from modules.result_table import show_table
15
- import modules.semantic_search as semantic_search
16
  from functions.filter_projects import filter_projects
17
  from functions.calc_matches import calc_matches
18
  import psutil
@@ -107,12 +106,7 @@ def load_embeddings_and_index():
107
  sentences = stored_data["sentences"]
108
  embeddings = stored_data["embeddings"]
109
 
110
- # Load or create FAISS index
111
- dimension = embeddings.shape[1]
112
- faiss_index = faiss.IndexFlatL2(dimension)
113
- faiss_index.add(embeddings)
114
-
115
- return sentences, embeddings, faiss_index
116
 
117
  # USE CACHE FUNCTIONS
118
  sim_matrix = load_sim_matrix()
@@ -124,8 +118,9 @@ SDG_NAMES = getSDG()
124
 
125
  COUNTRY_OPTION_LIST = getCountry()
126
 
 
127
  model = load_model()
128
- sentences, embeddings, faiss_index = load_embeddings_and_index()
129
 
130
  def show_page():
131
  st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
@@ -168,8 +163,6 @@ def show_page():
168
 
169
  with col2:
170
  # COUNTRY SELECTION
171
-
172
-
173
  country_option = st.multiselect(
174
  'Country / Countries',
175
  COUNTRY_OPTION_LIST,
@@ -186,6 +179,9 @@ def show_page():
186
  orga_list,
187
  placeholder="Select"
188
  )
 
 
 
189
 
190
 
191
  # CRS CODE LIST
@@ -205,10 +201,15 @@ def show_page():
205
  orga_code_list = [option.split("(")[1][:-1].lower() for option in orga_option]
206
 
207
  # FILTER DF WITH SELECTED FILTER OPTIONS
208
- filtered_df = filter_projects(projects_df, crs3_list, crs5_list, sdg_str, country_code_list, orga_code_list)
 
 
 
 
 
209
 
210
  # FIND MATCHES
211
- p1_df, p2_df = calc_matches(filtered_df, projects_df, sim_matrix)
212
 
213
  # SHOW THE RESULT
214
  show_table(p1_df, p2_df)
 
12
  import faiss
13
  from sentence_transformers import SentenceTransformer
14
  from modules.result_table import show_table
 
15
  from functions.filter_projects import filter_projects
16
  from functions.calc_matches import calc_matches
17
  import psutil
 
106
  sentences = stored_data["sentences"]
107
  embeddings = stored_data["embeddings"]
108
 
109
+ return sentences, embeddings
 
 
 
 
 
110
 
111
  # USE CACHE FUNCTIONS
112
  sim_matrix = load_sim_matrix()
 
118
 
119
  COUNTRY_OPTION_LIST = getCountry()
120
 
121
+ # LOAD MODEL FROM CACHE FO SEMANTIC SEARCH
122
  model = load_model()
123
+ sentences, embeddings = load_embeddings_and_index()
124
 
125
  def show_page():
126
  st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
 
163
 
164
  with col2:
165
  # COUNTRY SELECTION
 
 
166
  country_option = st.multiselect(
167
  'Country / Countries',
168
  COUNTRY_OPTION_LIST,
 
179
  orga_list,
180
  placeholder="Select"
181
  )
182
+
183
+ # SEARCH BOX
184
+ query = st.text_input("Enter your search query:")
185
 
186
 
187
  # CRS CODE LIST
 
201
  orga_code_list = [option.split("(")[1][:-1].lower() for option in orga_option]
202
 
203
  # FILTER DF WITH SELECTED FILTER OPTIONS
204
+
205
+ TOP_X_PROJECTS = 30
206
+ filtered_df = filter_projects(projects_df, crs3_list, crs5_list, sdg_str, country_code_list, orga_code_list, query, model, embeddings, TOP_X_PROJECTS)
207
+ #with col2:
208
+ # Semantic Search
209
+ #searched_filtered_df = semantic_search.show_search(model, embeddings, sentences, filtered_df, TOP_X_PROJECTS)
210
 
211
  # FIND MATCHES
212
+ p1_df, p2_df = calc_matches(filtered_df, projects_df, sim_matrix, TOP_X_PROJECTS)
213
 
214
  # SHOW THE RESULT
215
  show_table(p1_df, p2_df)