"""Streamlit page that shows similar project pairs for user-selected filters."""

################
# DEPENDENCIES #
################
import streamlit as st
import pandas as pd
from scipy.sparse import load_npz
import pickle
import faiss
from sentence_transformers import SentenceTransformer

from modules.result_table import show_table
from functions.filter_projects import filter_projects
from functions.calc_matches import calc_matches

import psutil
import os
import gc
import re


def get_process_memory():
    """Return the resident set size (RSS) of the current process in MiB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)


################
# CACHED DATA  #
################

# The large dense matrices and the embeddings are cached with st.cache_resource
# instead of st.cache_data: cache_data stores a pickled copy in its cache AND
# returns a fresh unpickled copy to the caller, roughly doubling the RAM held
# for these arrays. The returned objects are shared, so callers should not
# modify them in place.

@st.cache_resource
def load_sim_matrix():
    """Load the extended similarity matrix and return it as a dense array."""
    return load_npz("src/extended_similarities.npz").toarray()


@st.cache_resource
def load_nonsameorga_sim_matrix():
    """Load the similarity matrix variant used for "different organizations only" matching, as a dense array."""
    return load_npz("src/extended_similarities_nonsimorga.npz").toarray()


@st.cache_data
def load_projects():
    """Load the per-aspect project CSVs and inner-join them on ``iati_id``.

    Because every merge is an inner join, a project missing from any one of
    the files is dropped from the result.
    """
    part_paths = (
        "src/projects/project_orgas.csv",
        "src/projects/project_region.csv",
        "src/projects/project_sector.csv",
        "src/projects/project_status.csv",
        "src/projects/project_texts.csv",
    )
    frames = [pd.read_csv(path) for path in part_paths]

    projects_df = frames[0]
    for frame in frames[1:]:
        projects_df = pd.merge(projects_df, frame, on='iati_id', how='inner')
    return projects_df


@st.cache_data
def getCRS3():
    """Return a dict mapping the label ``"<name> - <code>"`` to its CRS 3 code."""
    crs3_df = pd.read_csv('src/codelists/crs3_codes.csv')
    return {f"{name} - {code}": code for name, code in zip(crs3_df['name'], crs3_df['code'])}


@st.cache_data
def getCRS5():
    """Return a dict mapping each CRS 5 code to a one-element list ``["<name> - <code>"]``.

    The single-element list shape is kept because ``show_page`` reads ``txt[0]``.
    """
    crs5_df = pd.read_csv('src/codelists/crs5_codes.csv')
    return {code: [f"{name} - {code}"] for name, code in zip(crs5_df['name'], crs5_df['code'])}


@st.cache_data
def getSDG():
    """Return the list of SDG names from the SDG codelist, in file order."""
    sdg_df = pd.read_csv('src/codelists/sdg_goals.csv')
    return sdg_df['name'].tolist()


@st.cache_data
def getCountry():
    """Return country option labels of the form ``"<Country> (<CODE>)"``.

    NOTE(review): ``code[-3:-1]`` takes the two characters before the last
    one, so this assumes each "Alpha-2 code" value carries one trailing
    character (e.g. a closing quote) after the code — confirm against the CSV.
    """
    country_df = pd.read_csv('src/codelists/country_codes_ISO3166-1alpha-2.csv')
    return [
        f"{name} ({code[-3:-1].upper()})"
        for name, code in zip(country_df['Country'], country_df['Alpha-2 code'])
    ]


@st.cache_resource
def load_model():
    """Load the sentence-transformer model used for semantic search."""
    return SentenceTransformer('all-MiniLM-L6-v2')


@st.cache_resource
def load_embeddings_and_index():
    """Load the precomputed text embeddings from ``src/embeddings.pkl``.

    Only the ``"embeddings"`` entry of the pickled dict is returned.
    The file is unpickled, so it must come from a trusted source
    (pickle can execute arbitrary code on load).
    """
    with open("src/embeddings.pkl", "rb") as fIn:
        stored_data = pickle.load(fIn)
    return stored_data["embeddings"]


#########################
# USE CACHE FUNCTIONS   #
#########################
sim_matrix = load_sim_matrix()
nonsameorgas_sim_matrix = load_nonsameorga_sim_matrix()
projects_df = load_projects()

CRS3_MERGED = getCRS3()
CRS5_MERGED = getCRS5()
SDG_NAMES = getSDG()
COUNTRY_OPTION_LIST = getCountry()

# Model and embeddings for the semantic search
model = load_model()
embeddings = load_embeddings_and_index()


####################
# PAGE HELPERS     #
####################

def _crs5_options(crs3_option):
    """Return the CRS 5 labels whose code starts with one of the selected CRS 3 codes."""
    return [
        txt[0].replace('"', "")
        for crs3_item in crs3_option
        for code, txt in CRS5_MERGED.items()
        # A CRS 3 label ends with its 3-digit code; CRS 5 codes start with it.
        if str(code)[:3] == str(crs3_item)[-3:]
    ]


def _orga_options(df):
    """Build ``"<Full Name> (<ABBR>)"`` labels, one per distinct organization abbreviation.

    Abbreviation and full name are deduplicated together so each label pairs
    values from the same row (first occurrence wins, in order of appearance).
    """
    orgas = df[["orga_abbreviation", "orga_full_name"]].drop_duplicates(subset="orga_abbreviation")
    return [
        f"{full_name} ({abbreviation.upper()})"
        for abbreviation, full_name in zip(orgas["orga_abbreviation"], orgas["orga_full_name"])
    ]


def _orga_code(option):
    """Extract the lower-cased abbreviation from a ``"<Full Name> (<ABBR>)"`` label."""
    # rsplit so a "(" inside the full name does not break the extraction.
    return option.rsplit("(", 1)[-1][:-1].lower()


def _sdg_code(sdg_option):
    """Return the SDG number of the selected option as a string, or "" when nothing is selected.

    Takes the whole leading number so two-digit SDGs (e.g. "10 ...") are not
    truncated to "1". Falls back to the first character when the label does
    not start with digits.
    NOTE(review): assumes SDG names start with their number — confirm against sdg_goals.csv.
    """
    if not sdg_option:
        return ""
    match = re.match(r"\d+", sdg_option)
    return match.group() if match else sdg_option[0]


def show_page():
    """Render the similarities page.

    Shows the filter widgets (CRS 3/5, SDG, country, organization), a search
    box, and the matched project pairs. Matches are computed only when
    ``filter_projects`` returns a DataFrame; otherwise a hint is shown.
    """
    st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
    st.write("Similarities")

    st.session_state.crs5_option_disabled = True
    col1, col2 = st.columns([1, 1])

    with col1:
        # CRS 3 SELECTION (options are the "<name> - <code>" keys of CRS3_MERGED)
        crs3_option = st.multiselect(
            'CRS 3',
            CRS3_MERGED,
            placeholder="Select"
        )

        # CRS 5 SELECTION: only enabled once at least one CRS 3 code is selected
        if crs3_option:
            st.session_state.crs5_option_disabled = False

        crs5_option = st.multiselect(
            'CRS 5',
            _crs5_options(crs3_option),
            placeholder="Select",
            disabled=st.session_state.crs5_option_disabled
        )

        # SDG SELECTION
        # NOTE(review): the last SDG name is deliberately excluded — confirm why.
        sdg_option = st.selectbox(
            label='SDG',
            index=None,
            placeholder="Select SDG",
            options=SDG_NAMES[:-1],
        )

        different_orga_checkbox = st.checkbox("Only matches between different organizations")

    with col2:
        # COUNTRY SELECTION
        country_option = st.multiselect(
            'Country / Countries',
            COUNTRY_OPTION_LIST,
            placeholder="Select"
        )

        # ORGA SELECTION
        orga_option = st.multiselect(
            'Development Bank / Organization',
            _orga_options(projects_df),
            placeholder="Select"
        )

    # SEARCH BOX
    query = st.text_input("Enter your search query:")

    # Translate the selected widget labels into the codes filter_projects expects.
    crs3_list = [option[-3:] for option in crs3_option]
    crs5_list = [option[-5:] for option in crs5_option]
    sdg_str = _sdg_code(sdg_option)
    country_code_list = [option[-3:-1] for option in country_option]  # "Name (XX)" -> "XX"
    orga_code_list = [_orga_code(option) for option in orga_option]

    # FILTER DF WITH SELECTED FILTER OPTIONS
    TOP_X_PROJECTS = 30
    filtered_df = filter_projects(
        projects_df, crs3_list, crs5_list, sdg_str, country_code_list,
        orga_code_list, query, model, embeddings, TOP_X_PROJECTS
    )

    if isinstance(filtered_df, pd.DataFrame):
        # FIND MATCHES, optionally restricted to pairs from different organizations
        matrix = nonsameorgas_sim_matrix if different_orga_checkbox else sim_matrix
        p1_df, p2_df = calc_matches(filtered_df, projects_df, matrix, TOP_X_PROJECTS)

        # SHOW THE RESULT
        show_table(p1_df, p2_df)
        del p1_df, p2_df
    else:
        st.write("Select at least one CRS 3, SDG or type in a query")

    # Drop per-run intermediates and force a collection to keep RAM usage down between reruns.
    del crs3_list, crs5_list, sdg_str, filtered_df
    gc.collect()