Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
""" | |
Page for similarities | |
""" | |
################ | |
# DEPENDENCIES # | |
################ | |
import gc
import os
import pickle
import re

import faiss
import pandas as pd
import psutil
import streamlit as st
from scipy.sparse import load_npz
from sentence_transformers import SentenceTransformer

from functions.calc_matches import calc_matches
from functions.filter_projects import filter_projects
from modules.result_table import show_table
def get_process_memory():
    """Return the resident set size (RSS) of this process in MiB.

    The byte count reported by psutil is divided by 1024 * 1024.
    """
    current_process = psutil.Process(os.getpid())
    rss_in_bytes = current_process.memory_info().rss
    bytes_per_mib = 1024 * 1024
    return rss_in_bytes / bytes_per_mib
# Catch DATA | |
# Load Similarity matrix | |
def load_sim_matrix(): | |
loaded_matrix = load_npz("src/extended_similarities.npz") | |
dense_matrix = loaded_matrix.toarray() | |
return dense_matrix | |
# Load Non Similar Orga Matrix
def load_nonsameorga_sim_matrix(path="src/extended_similarities_nonsimorga.npz"):
    """Load the similarity matrix variant for different organizations, as a dense array.

    The default file name suggests it holds similarities restricted to
    projects of different organizations. NOTE(review): confirm how the
    file was produced.

    Args:
        path: Location of a sparse matrix stored with ``scipy.sparse.save_npz``.
            Defaults to the file shipped with the app.

    Returns:
        numpy.ndarray: the dense form of the stored sparse matrix.
    """
    # toarray() materializes every entry, so memory use grows with
    # rows * cols even though the file on disk is stored sparsely.
    return load_npz(path).toarray()
# Load Projects DFs
def load_projects(base_dir="src/projects"):
    """Load the per-aspect project tables and join them into one DataFrame.

    Reads ``project_orgas.csv``, ``project_region.csv``, ``project_sector.csv``,
    ``project_status.csv`` and ``project_texts.csv`` from *base_dir*. It then
    inner-joins them in that order on ``iati_id``, so only projects present in
    every table are kept.

    Args:
        base_dir: Directory holding the CSV files. Defaults to ``src/projects``.

    Returns:
        pandas.DataFrame: the merged project table.
    """
    table_names = ("orgas", "region", "sector", "status", "texts")
    frames = [
        pd.read_csv(os.path.join(base_dir, f"project_{name}.csv"))
        for name in table_names
    ]

    # Fold the tables left to right, same join order as listed above.
    projects_df = frames[0]
    for frame in frames[1:]:
        projects_df = pd.merge(projects_df, frame, on='iati_id', how='inner')
    return projects_df
# Load CRS 3 data
def getCRS3():
    """Build the CRS 3 selection options.

    Reads ``src/codelists/crs3_codes.csv`` and maps each display label
    ``"<name> - <code>"`` to its code, in file order.
    """
    codelist = pd.read_csv('src/codelists/crs3_codes.csv')
    label_names = codelist['name'].tolist()
    crs3_codes = codelist['code'].tolist()

    options = {}
    for label_name, crs3_code in zip(label_names, crs3_codes):
        options[f"{label_name} - {crs3_code}"] = crs3_code
    return options
# Load CRS 5 data
def getCRS5():
    """Build the CRS 5 lookup table.

    Reads the CRS 5 codelist ``src/codelists/crs5_codes.csv`` and maps each
    code to a single-element list holding its display label
    ``"<name> - <code>"``, in file order.
    """
    codelist = pd.read_csv('src/codelists/crs5_codes.csv')
    crs5_codes = codelist['code'].tolist()
    label_names = codelist['name'].tolist()

    lookup = {}
    for crs5_code, label_name in zip(crs5_codes, label_names):
        lookup[crs5_code] = [f"{label_name} - {crs5_code}"]
    return lookup
# Load SDG data
def getSDG():
    """Return the 'name' column of ``src/codelists/sdg_goals.csv`` as a list, in file order."""
    return pd.read_csv('src/codelists/sdg_goals.csv')['name'].tolist()
# Load Country Data
def getCountry():
    """Build the country selection labels.

    Reads ``src/codelists/country_codes_ISO3166-1alpha-2.csv`` and returns one
    label per row, ``"<Country> (<XX>)"``, in file order.

    The two-letter code is the ``[-3:-1]`` slice of the raw 'Alpha-2 code'
    cell (the two characters before the last one), upper-cased.
    NOTE(review): this presumes cells arrive from pandas looking like
    ``' "DE"'`` (with a trailing quote character); confirm against the CSV.
    """
    country_df = pd.read_csv('src/codelists/country_codes_ISO3166-1alpha-2.csv')
    country_names = country_df['Country'].tolist()
    country_codes = country_df['Alpha-2 code'].tolist()
    return [
        f"{name} ({code[-3:-1].upper()})"
        for name, code in zip(country_names, country_codes)
    ]
# Load Sentence Transformer Model
def load_model():
    """Create the SentenceTransformer model used for semantic search.

    Returns:
        SentenceTransformer: the pretrained ``all-MiniLM-L6-v2`` model.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')


# NOTE: a disabled, string-quoted embeddings loader used to sit here. It was
# `load_embeddings_and_index()`, decorated with @st.cache_data, and read the
# pickled "src/embeddings.pkl". It never ran and has been removed as dead code.
# LOAD DATA AT MODULE LEVEL
# The loaders above are plain functions (no st.cache_* decorator), so each of
# these values is recomputed every time this module is executed.
sim_matrix = load_sim_matrix()
nonsameorgas_sim_matrix = load_nonsameorga_sim_matrix()
projects_df = load_projects()
CRS3_MERGED = getCRS3()
CRS5_MERGED = getCRS5()
SDG_NAMES = getSDG()
COUNTRY_OPTION_LIST = getCountry()

# LOAD MODEL FOR SEMANTIC SEARCH
model = load_model()

# Stored embeddings are not loaded at the moment; an empty list is passed on
# in their place.
# NOTE(review): confirm filter_projects copes with a non-empty query when
# `embeddings` is empty.
embeddings = []
def show_page():
    """Render the "Similarities" page.

    The page collects CRS 3 / CRS 5 / SDG / country / organization filters and
    a free-text query, then narrows ``projects_df`` with ``filter_projects``.
    When that returns a DataFrame, the page shows the project pairs found by
    ``calc_matches``. If the checkbox is ticked, matching uses the
    different-organizations similarity matrix. Otherwise a hint is shown.
    Reads the module-level data (matrices, codelists, model, embeddings).
    """
    st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
    st.write("Similarities")

    st.session_state.crs5_option_disabled = True
    col1, col2 = st.columns([1, 1])

    with col1:
        # CRS 3 SELECTION
        crs3_option = st.multiselect(
            'CRS 3',
            CRS3_MERGED,
            placeholder="Select"
        )

        # CRS 5 SELECTION
        # Only enable the CRS 5 field once at least one CRS 3 code is selected.
        if crs3_option:
            st.session_state.crs5_option_disabled = False

        # CRS 5 labels whose code starts with a selected CRS 3 code; the
        # CRS 3 code is the last three characters of its label.
        crs5_labels = [
            txt[0].replace('"', "")
            for crs3_item in crs3_option
            for code, txt in CRS5_MERGED.items()
            if str(code)[:3] == str(crs3_item)[-3:]
        ]
        crs5_option = st.multiselect(
            'CRS 5',
            crs5_labels,
            placeholder="Select",
            disabled=st.session_state.crs5_option_disabled
        )

        # SDG SELECTION
        # NOTE(review): the last SDG_NAMES entry is excluded; confirm intent.
        sdg_option = st.selectbox(
            label='SDG',
            index=None,
            placeholder="Select SDG",
            options=SDG_NAMES[:-1],
        )

        different_orga_checkbox = st.checkbox("Only matches between different organizations")

    with col2:
        # COUNTRY SELECTION
        country_option = st.multiselect(
            'Country / Countries',
            COUNTRY_OPTION_LIST,
            placeholder="Select"
        )

        # ORGA SELECTION
        # Pair abbreviation and full name row by row. Two independent
        # .unique() arrays only line up when the mapping is strictly
        # one-to-one; otherwise labels get mismatched or indexing fails.
        orga_pairs = projects_df[["orga_abbreviation", "orga_full_name"]].drop_duplicates()
        orga_list = [
            f"{full_name} ({abbreviation.upper()})"
            for abbreviation, full_name in zip(
                orga_pairs["orga_abbreviation"], orga_pairs["orga_full_name"]
            )
        ]
        orga_option = st.multiselect(
            'Development Bank / Organization',
            orga_list,
            placeholder="Select"
        )

        # SEARCH BOX
        query = st.text_input("Enter your search query:")

    # CRS CODE LISTS: codes are the trailing digits of each label
    crs3_list = [i[-3:] for i in crs3_option]
    crs5_list = [i[-5:] for i in crs5_option]

    # SDG CODE
    if sdg_option is not None:
        # Use the whole leading number: taking only the first character would
        # reduce a label such as "12 ..." to "1".
        # NOTE(review): assumes SDG names begin with their goal number; names
        # without one fall back to the first character, as before.
        leading_number = re.match(r"\d+", sdg_option)
        sdg_str = leading_number.group(0) if leading_number else sdg_option[0]
    else:
        sdg_str = ""

    # COUNTRY CODES LIST: the two letters inside the trailing "(XX)"
    country_code_list = [option[-3:-1] for option in country_option]

    # ORGANIZATION CODES LIST: text inside the last "(...)", lower-cased.
    # rsplit keeps this right when the full name itself contains "(".
    orga_code_list = [option.rsplit("(", 1)[1][:-1].lower() for option in orga_option]

    # FILTER DF WITH SELECTED FILTER OPTIONS
    TOP_X_PROJECTS = 30
    filtered_df = filter_projects(
        projects_df, crs3_list, crs5_list, sdg_str, country_code_list,
        orga_code_list, query, model, embeddings, TOP_X_PROJECTS,
    )

    if isinstance(filtered_df, pd.DataFrame):
        # FIND MATCHES
        matrix = nonsameorgas_sim_matrix if different_orga_checkbox else sim_matrix
        p1_df, p2_df = calc_matches(filtered_df, projects_df, matrix, TOP_X_PROJECTS)
        # SHOW THE RESULT
        show_table(p1_df, p2_df)
        del p1_df, p2_df
    else:
        st.write("Select at least one CRS 3, SDG or type in a query")

    # Drop this run's intermediate data and force a collection pass,
    # presumably to keep the reported RAM usage down.
    del crs3_list, crs5_list, sdg_str, filtered_df
    gc.collect()