"""
Page for similarities
"""
################
# DEPENDENCIES #
################
import streamlit as st
import pandas as pd
from scipy.sparse import load_npz
import pickle
import faiss
from sentence_transformers import SentenceTransformer
from modules.result_table import show_table
from functions.filter_projects import filter_projects
from functions.calc_matches import calc_matches
import psutil
import os
import gc
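
# Helper to report the app's current resident memory in MB (displayed at the top of the page)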
def get_process_memory():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)

# Cache DATA
# Load Similarity matrix
@st.cache_data
def load_sim_matrix():
    loaded_matrix = load_npz("src/extended_similarities.npz")
    # densify once so downstream code can index it as a plain NumPy array
    dense_matrix = loaded_matrix.toarray()

    return dense_matrix
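
# The project tables below all share the 'iati_id' key and are merged into a single DataFrame.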
# Load Projects DFs
@st.cache_data
def load_projects():
    orgas_df = pd.read_csv("src/projects/project_orgas.csv")
    region_df = pd.read_csv("src/projects/project_region.csv")
    sector_df = pd.read_csv("src/projects/project_sector.csv")
    status_df = pd.read_csv("src/projects/project_status.csv")
    texts_df = pd.read_csv("src/projects/project_texts.csv")

    projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')

    return projects_df

# Load CRS 3 data
@st.cache_data
def getCRS3():
    # Read in CRS3 CODELISTS
    crs3_df = pd.read_csv('src/codelists/crs3_codes.csv')
    CRS3_CODES = crs3_df['code'].tolist()
    CRS3_NAME = crs3_df['name'].tolist()
    # map "name - code" labels to their codes; the labels are shown in the CRS 3 multiselect
    CRS3_MERGED = {f"{name} - {code}": code for name, code in zip(CRS3_NAME, CRS3_CODES)}

    return CRS3_MERGED

# Load CRS 5 data
@st.cache_data
def getCRS5():
    # Read in CRS5 CODELISTS
    crs5_df = pd.read_csv('src/codelists/crs5_codes.csv')
    CRS5_CODES = crs5_df['code'].tolist()
    CRS5_NAME = crs5_df['name'].tolist()
    # map each code to a one-element ["name - code"] label list used to build the CRS 5 options
    CRS5_MERGED = {code: [f"{name} - {code}"] for name, code in zip(CRS5_NAME, CRS5_CODES)}

    return CRS5_MERGED

# Load SDG data
@st.cache_data
def getSDG():
    # Read in SDG CODELISTS
    sdg_df = pd.read_csv('src/codelists/sdg_goals.csv')
    SDG_NAMES = sdg_df['name'].tolist()

    return SDG_NAMES

# Load Country Data
@st.cache_data
def getCountry():
    # Read in countries from codelist
    country_df = pd.read_csv('src/codelists/country_codes_ISO3166-1alpha-2.csv')
    COUNTRY_CODES = country_df['Alpha-2 code'].tolist()
    COUNTRY_NAMES = country_df['Country'].tolist()
    # build "Country (XX)" labels; the [-3:-1] slice extracts the two-letter code from the raw codelist value
    COUNTRY_OPTION_LIST = [f"{COUNTRY_NAMES[i]} ({COUNTRY_CODES[i][-3:-1].upper()})" for i in range(len(COUNTRY_NAMES))]

    return COUNTRY_OPTION_LIST

# Load Sentence Transformer Model
@st.cache_resource
def load_model():
    model = SentenceTransformer('all-MiniLM-L6-v2')

    return model

# Load Embeddings
@st.cache_data
def load_embeddings_and_index():
    # Load embeddings
    with open("src/embeddings.pkl", "rb") as fIn:
        stored_data = pickle.load(fIn)
        sentences = stored_data["sentences"]
        embeddings = stored_data["embeddings"]

    return sentences, embeddings
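
# Calling the cached loaders at import time materialises the data once per process;
# subsequent reruns and sessions reuse the cached results.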
# USE CACHE FUNCTIONS
sim_matrix = load_sim_matrix()
projects_df = load_projects()
CRS3_MERGED = getCRS3()
CRS5_MERGED = getCRS5()
SDG_NAMES = getSDG()
COUNTRY_OPTION_LIST = getCountry()

# LOAD MODEL FROM CACHE FOR SEMANTIC SEARCH
model = load_model()
sentences, embeddings = load_embeddings_and_index()
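
# show_page renders the filter UI, applies the selected filters plus the semantic search query,
# and displays the matching project pairs found in the precomputed similarity matrix.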
def show_page():
    st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")

    st.write("Similarities")

    st.session_state.crs5_option_disabled = True

    col1, col2 = st.columns([1, 1])

    with col1:
        # CRS 3 SELECTION
        crs3_option = st.multiselect(
            'CRS 3',
            CRS3_MERGED,
            placeholder="Select"
        )

        # CRS 5 SELECTION
        ## Only enable crs5 select field when a crs3 code is selected
        if crs3_option:
            st.session_state.crs5_option_disabled = False

        ## define the list of crs5 codes dependent on the selected crs3 codes
        ## (a CRS 5 code belongs to a CRS 3 category when its first three digits match)
        crs5_list = [txt[0].replace('"', "") for crs3_item in crs3_option for code, txt in CRS5_MERGED.items() if str(code)[:3] == str(crs3_item)[-3:]]

        ## crs5 select field
        crs5_option = st.multiselect(
            'CRS 5',
            crs5_list,
            placeholder="Select",
            disabled=st.session_state.crs5_option_disabled
        )

        # SDG SELECTION
        sdg_option = st.selectbox(
            label='SDG',
            index=None,
            placeholder="Select SDG",
            options=SDG_NAMES[:-1],
        )

    with col2:
        # COUNTRY SELECTION
        country_option = st.multiselect(
            'Country / Countries',
            COUNTRY_OPTION_LIST,
            placeholder="Select"
        )

        # ORGA SELECTION
        orga_abbreviation = projects_df["orga_abbreviation"].unique()
        orga_full_names = projects_df["orga_full_name"].unique()
        orga_list = [f"{orga_full_names[i]} ({orga_abbreviation[i].upper()})" for i in range(len(orga_abbreviation))]

        orga_option = st.multiselect(
            'Development Bank / Organization',
            orga_list,
            placeholder="Select"
        )

    # SEARCH BOX
    query = st.text_input("Enter your search query:")

    # CRS CODE LIST
    # the numeric codes sit at the end of each "name - code" label
    crs3_list = [i[-3:] for i in crs3_option]
    crs5_list = [i[-5:] for i in crs5_option]

    # SDG CODE LIST
    if sdg_option is not None:
        # take the leading character of the selected SDG name as the goal identifier
        sdg_str = sdg_option[0]
    else:
        sdg_str = ""

    # COUNTRY CODES LIST
    # extract the two-letter code from the "Country (XX)" label
    country_code_list = [option[-3:-1] for option in country_option]

    # ORGANIZATION CODES LIST
    # extract the lower-case abbreviation from the "Full Name (ABBR)" label
    orga_code_list = [option.split("(")[1][:-1].lower() for option in orga_option]

    # FILTER DF WITH SELECTED FILTER OPTIONS
    TOP_X_PROJECTS = 30
    filtered_df = filter_projects(projects_df, crs3_list, crs5_list, sdg_str, country_code_list, orga_code_list, query, model, embeddings, TOP_X_PROJECTS)

    #with col2:
        # Semantic Search
        #searched_filtered_df = semantic_search.show_search(model, embeddings, sentences, filtered_df, TOP_X_PROJECTS)

    # FIND MATCHES
    p1_df, p2_df = calc_matches(filtered_df, projects_df, sim_matrix, TOP_X_PROJECTS)

    # SHOW THE RESULT
    show_table(p1_df, p2_df)

    # release the per-run result frames before the next rerun
    del p1_df, p2_df, crs3_list, crs5_list, sdg_str, filtered_df
    gc.collect()
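
# ----------------------------------------------------------------------------
# Usage sketch (assumption, not part of this module): in a multipage Streamlit
# app the main script would typically import this page module and call
# show_page() when the page is selected, e.g.
#
#   import similarity_page  # hypothetical module name for this file
#   similarity_page.show_page()
# ----------------------------------------------------------------------------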