""" |
Page for similarities |
""" |
import streamlit as st |
import pandas as pd |
from scipy.sparse import load_npz |
import pickle |
import faiss |
from sentence_transformers import SentenceTransformer |
from modules.result_table import show_table |
from functions.filter_projects import filter_projects |
from functions.calc_matches import calc_matches |
import psutil |
import os |
import gc |
def get_process_memory():
    """Return the resident set size (RSS) of this Python process in MiB."""
    current = psutil.Process(os.getpid())
    rss_bytes = current.memory_info().rss
    return rss_bytes / 2 ** 20
@st.cache_data
def load_sim_matrix():
    """Load the project similarity matrix and return it densified.

    The matrix is stored as a SciPy sparse ``.npz`` archive. It is converted
    to a dense array once and memoised by ``st.cache_data``.
    """
    return load_npz("src/extended_similarities.npz").toarray()
@st.cache_data
def load_nonsameorga_sim_matrix():
    """Load the alternative similarity matrix and return it densified.

    This matrix is used by the page when only matches between different
    organizations are requested. The file name suggests that same-organization
    pairs were removed beforehand (assumption, not visible here).
    """
    return load_npz("src/extended_similarities_nonsimorga.npz").toarray()
@st.cache_data
def load_projects():
    """Build the project table by inner-joining the per-aspect CSV files.

    The files are joined on ``iati_id`` in this fixed order: orgas, region,
    sector, status, texts. Only projects present in every file survive.
    """
    csv_paths = (
        "src/projects/project_orgas.csv",
        "src/projects/project_region.csv",
        "src/projects/project_sector.csv",
        "src/projects/project_status.csv",
        "src/projects/project_texts.csv",
    )
    frames = [pd.read_csv(path) for path in csv_paths]

    combined = frames[0]
    for frame in frames[1:]:
        combined = pd.merge(combined, frame, on='iati_id', how='inner')
    return combined
@st.cache_data
def getCRS3():
    """Map display labels of CRS-3 sectors to their codes.

    Returns:
        dict: ``"<name> - <code>" -> code`` for every row of
        ``src/codelists/crs3_codes.csv``. If a label repeats, the later
        row wins.
    """
    codes_df = pd.read_csv('src/codelists/crs3_codes.csv')
    labelled = {}
    for name, code in zip(codes_df['name'].tolist(), codes_df['code'].tolist()):
        labelled[f"{name} - {code}"] = code
    return labelled
@st.cache_data
def getCRS5():
    """Map each CRS-5 code to a one-element list holding its display label.

    Returns:
        dict: ``code -> ["<name> - <code>"]`` for every row of
        ``src/codelists/crs5_codes.csv``. If a code repeats, the later row
        wins.
    """
    codes_df = pd.read_csv('src/codelists/crs5_codes.csv')
    labelled = {}
    for name, code in zip(codes_df['name'].tolist(), codes_df['code'].tolist()):
        labelled[code] = [f"{name} - {code}"]
    return labelled
@st.cache_data
def getSDG():
    """Return the ``name`` column of ``src/codelists/sdg_goals.csv`` as a list."""
    return pd.read_csv('src/codelists/sdg_goals.csv')['name'].tolist()
@st.cache_data
def getCountry():
    """Build the labels offered in the country multiselect.

    Reads the ISO 3166-1 alpha-2 country list and formats each row as
    ``"<Country> (<CODE>)"``. The original version built this list but never
    returned it, so callers received ``None``.

    Returns:
        list[str]: one label per CSV row, in file order.
    """
    country_df = pd.read_csv('src/codelists/country_codes_ISO3166-1alpha-2.csv')
    country_codes = country_df['Alpha-2 code'].tolist()
    country_names = country_df['Country'].tolist()
    # [-3:-1] drops the final character and keeps the two before it. This
    # presumably strips quoting around the raw alpha-2 value stored in the
    # CSV. NOTE(review): confirm against the file's actual format.
    return [
        f"{name} ({code[-3:-1].upper()})"
        for name, code in zip(country_names, country_codes)
    ]
@st.cache_resource
def load_model():
    """Return the ``all-MiniLM-L6-v2`` SentenceTransformer.

    The model is held by ``st.cache_resource`` as one shared object instead of
    being reloaded on each rerun.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
@st.cache_data
def load_embeddings_and_index():
    """Load the precomputed project embeddings from ``src/embeddings.pkl``.

    Despite the name, this function builds no index. It only returns the
    ``"embeddings"`` entry of the pickled dictionary.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. This is acceptable only for a trusted, app-bundled artifact;
    # never point it at user-supplied data.
    with open("src/embeddings.pkl", "rb") as handle:
        payload = pickle.load(handle)
    return payload["embeddings"]
# Shared data and models, loaded once when this module is imported. Each
# loader is wrapped in a Streamlit cache decorator.

# Dense similarity matrices: the general one, and the one used when the
# "Only matches between different organizations" checkbox is ticked.
sim_matrix, nonsameorgas_sim_matrix = load_sim_matrix(), load_nonsameorga_sim_matrix()

# Project table (orgas/region/sector/status/texts inner-joined on iati_id).
projects_df = load_projects()

# Code lists backing the filter widgets.
CRS3_MERGED, CRS5_MERGED, SDG_NAMES, COUNTRY_OPTION_LIST = (
    getCRS3(),
    getCRS5(),
    getSDG(),
    getCountry(),
)

# Sentence-embedding model and the precomputed project embeddings.
model, embeddings = load_model(), load_embeddings_and_index()
# Number of projects/matches requested from filter_projects() and
# calc_matches().
# NOTE(review): show_page() referenced TOP_X_PROJECTS, but this module never
# defined it, so every render raised NameError. 30 is a placeholder value;
# confirm the intended size.
TOP_X_PROJECTS = 30


def _orga_options(df):
    """Return one "<Full name> (<ABBR>)" label per distinct organization abbreviation."""
    # Deduplicate abbreviation/name pairs together. Two independent .unique()
    # calls only line up when the mapping is strictly one-to-one; otherwise
    # labels get mismatched or indexing raises IndexError.
    orgas = df[["orga_abbreviation", "orga_full_name"]].drop_duplicates(
        subset="orga_abbreviation"
    )
    return [
        f"{full_name} ({abbreviation.upper()})"
        for abbreviation, full_name in zip(
            orgas["orga_abbreviation"], orgas["orga_full_name"]
        )
    ]


def show_page():
    """Render the "Similarities" page.

    The page collects filter selections: CRS 3/5 codes, SDG, countries,
    organizations, and a free-text query. It narrows the projects with
    ``filter_projects`` and shows the matching project pairs computed by
    ``calc_matches`` using the module-level similarity matrices.
    """
    st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
    st.write("Similarities")

    # CRS 5 stays disabled until at least one CRS 3 code is selected.
    st.session_state.crs5_option_disabled = True
    col1, col2 = st.columns([1, 1])

    with col1:
        crs3_option = st.multiselect(
            'CRS 3',
            list(CRS3_MERGED.keys()),
            placeholder="Select"
        )
        if crs3_option:
            st.session_state.crs5_option_disabled = False

        # Offer only CRS 5 codes whose 3-digit prefix matches a selected CRS 3
        # label (CRS 3 labels end in their 3-digit code).
        crs5_choices = [
            txt[0].replace('"', "")
            for crs3_item in crs3_option
            for code, txt in CRS5_MERGED.items()
            if str(code)[:3] == str(crs3_item)[-3:]
        ]
        crs5_option = st.multiselect(
            'CRS 5',
            crs5_choices,
            placeholder="Select",
            disabled=st.session_state.crs5_option_disabled
        )

        # The last entry of SDG_NAMES is deliberately not offered (reason not
        # visible in this module).
        sdg_option = st.selectbox(
            label='SDG',
            index=None,
            placeholder="Select SDG",
            options=SDG_NAMES[:-1],
        )
        different_orga_checkbox = st.checkbox("Only matches between different organizations")

    with col2:
        country_option = st.multiselect(
            'Country / Countries',
            # `or []` keeps the widget usable even if the country list is missing.
            COUNTRY_OPTION_LIST or [],
            placeholder="Select"
        )
        orga_option = st.multiselect(
            'Development Bank / Organization',
            _orga_options(projects_df),
            placeholder="Select"
        )

    query = st.text_input("Enter your search query:")

    # Turn the widget labels back into the codes filter_projects() expects.
    crs3_list = [i[-3:] for i in crs3_option]
    crs5_list = [i[-5:] for i in crs5_option]
    # NOTE(review): only the first character of the SDG label is passed on.
    # If labels are numbered like "10. ...", SDGs 10-17 collapse to "1".
    # Confirm the label format in src/codelists/sdg_goals.csv.
    sdg_str = sdg_option[0] if sdg_option is not None else ""
    # Country labels end in "(XX)".
    country_code_list = [option[-3:-1] for option in country_option]
    # rsplit: an organization's full name may itself contain "(".
    orga_code_list = [option.rsplit("(", 1)[1][:-1].lower() for option in orga_option]

    filtered_df = filter_projects(
        projects_df, crs3_list, crs5_list, sdg_str, country_code_list,
        orga_code_list, query, model, embeddings, TOP_X_PROJECTS,
    )
    # filter_projects() evidently returns something other than a DataFrame
    # when no usable filter was given (assumption based on the message below).
    if isinstance(filtered_df, pd.DataFrame):
        matrix = nonsameorgas_sim_matrix if different_orga_checkbox else sim_matrix
        p1_df, p2_df = calc_matches(filtered_df, projects_df, matrix, TOP_X_PROJECTS)
        show_table(p1_df, p2_df)
        del p1_df, p2_df
    else:
        st.write("Select at least one CRS 3, SDG or type in a query")

    # Drop the per-run intermediates and force a collection, keeping the RSS
    # reported at the top of the page low.
    del crs3_list, crs5_list, sdg_str, filtered_df
    gc.collect()