import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download

# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'

# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg

# Logged-out users are sent back to app.py; logged-in users get the navigation menu
menu_with_redirect()

# Page banner
header_path = project_config.MEDIA_DIR / 'predict_header.svg'
st.image(str(header_path), use_column_width=True)

# Pull the active query out of session state once so the headings below stay readable
query = st.session_state.query
target_type_raw = query['target_node_type']

# Section title built from the requested target node type
st.subheader(f"{capitalize_after_slash(target_type_raw)} Search", divider = "blue")

# One-line summary of the current query: source node ➡️ relation ➡️ target node type
source_label = query['source_node'].replace('_', ' ')
relation_label = query['relation'].replace('_', '-')
target_label = target_type_raw.replace('_', ' ')
st.markdown(f"**Query:** {source_label} ➡️ {relation_label} ➡️ {target_label}")

@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():
    """Download the artifacts of the selected model checkpoint from the Hugging Face Hub.

    All three files come from the ``ayushnoori/galaxy`` repository and are
    authenticated with the ``HF_TOKEN`` Streamlit secret.

    Returns:
        Tuple ``(embed_path, relation_weights_path, edge_types_path)`` of the
        paths returned by ``hf_hub_download`` for the embeddings, relation
        weights, and edge types files, in that order.
    """
    # Checkpoint in use. Earlier candidates were
    # "2024_05_22_11_59_43_epoch=18-step=22912" and
    # "2024_03_29_04_12_52_epoch=3-step=54291".
    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"

    # File-name suffixes appended to the checkpoint name, in return order
    artifact_suffixes = (
        "-thresh=4000_embeddings.pt",
        "_relation_weights.pt",
        "_edge_types.pt",
    )
    hf_token = st.secrets["HF_TOKEN"]

    # The generator is consumed in order, so files download one after another
    embed_path, relation_weights_path, edge_types_path = (
        hf_hub_download(repo_id="ayushnoori/galaxy",
                        filename=(best_ckpt + suffix),
                        token=hf_token)
        for suffix in artifact_suffixes
    )
    return embed_path, relation_weights_path, edge_types_path

@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Load the model artifacts previously saved with ``torch.save``.

    Every object is mapped onto the CPU, so a checkpoint saved from a GPU still
    loads on a CPU-only host. The loaded tensors can also be converted to NumPy
    arrays without an explicit device transfer.

    Args:
        embed_path: Path to the node embeddings file.
        relation_weights_path: Path to the per-edge-type relation weights file.
        edge_types_path: Path to the saved edge types file.

    Returns:
        Tuple ``(embeddings, relation_weights, edge_types)``, in the order the
        paths were given.
    """
    # NOTE(review): torch.load unpickles its input; only point these paths at
    # trusted checkpoint files (here, the project's own Hugging Face repo).
    embeddings = torch.load(embed_path, map_location = 'cpu')
    relation_weights = torch.load(relation_weights_path, map_location = 'cpu')
    edge_types = torch.load(edge_types_path, map_location = 'cpu')

    return embeddings, relation_weights, edge_types

# Fetch the knowledge-graph node table, then download (cached) and load the model artifacts
kg_nodes = load_kg()
model_paths = get_embeddings()
embed_path, relation_weights_path, edge_types_path = model_paths
embeddings, relation_weights, edge_types = load_embeddings(
    embed_path=embed_path,
    relation_weights_path=relation_weights_path,
    edge_types_path=edge_types_path,
)

# Node types whose URL template below just echoes the ID; their 'Database'
# column is therefore not rendered as a clickable link column.
_NON_LINK_NODE_TYPES = ['disease', 'anatomy']

# tab10 colors, cycled when marking selected nodes on the rank plot
_RANK_PLOT_PALETTE = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
                      "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]


def _style_val(val):
    """Return the CSS for one validation cell: green for a hit ('✅'), light grey otherwise."""
    if val == '✅':
        return 'background-color: #C2EABD;'
    return 'background-color: #F5F5F5;'


def _show_results_table(data, target_node_type, display_database):
    """Render a results table (DataFrame or pandas Styler) with the index hidden.

    For node types with a URL template, the 'Database' column is a clickable
    link column labelled ``display_database``. For other types it is plain text.
    """
    if target_node_type not in _NON_LINK_NODE_TYPES:
        st.dataframe(data, use_container_width = True, hide_index = True,
                     column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                            help = "Click to visit external database.",
                                                                            display_text = display_database)})
    else:
        st.dataframe(data, use_container_width = True, hide_index = True)


def _plot_selected_ranks(display_data, selected_data, show_labels):
    """Plot score vs. rank for all predictions and mark each selected node.

    Args:
        display_data: All predictions, with numeric 'Rank' and 'Score' columns.
        selected_data: Plain DataFrame of the selected rows, with a 0-based
            integer index and numeric 'Rank'/'Score' plus 'Name'.
        show_labels: If True, write each node's name next to its marker.

    Returns:
        The matplotlib Figure. The caller must close it after rendering.
    """
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Dot on the curve plus a dashed vertical line per selected node. Colors
    # cycle, so selecting more nodes than palette entries does not raise.
    for i, node in selected_data.iterrows():
        color = _RANK_PLOT_PALETTE[i % len(_RANK_PLOT_PALETTE)]
        ax.scatter(node['Rank'], node['Score'], color = color, zorder = 3)
        ax.axvline(node['Rank'], color = color, linestyle = '--', linewidth = 1.5, label = node['Name'], zorder = 3)
        if show_labels:
            # Fixed offset of 100 ranks to the right of the marker
            ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = color, zorder = 3)

    ax.legend(loc = 'upper right', fontsize = 10)
    ax.grid(alpha = 0.2, zorder = 0)
    return fig


# Compute predictions
with st.spinner('Computing predictions...'):

    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Locate the source node; stop with a readable message instead of an IndexError
    source_rows = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)]
    if source_rows.empty:
        st.error(f"Source node '{source_node}' of type '{source_node_type}' was not found in the knowledge graph.")
        st.stop()
    src_index = source_rows.node_index.values[0]

    # Locate the (source type, relation, target type) triple among the model's edge types
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next((i for i, etype in enumerate(edge_types) if etype == query_edge_type), None)
    if edge_type_index is None:
        st.error(f"The model has no edge type {query_edge_type}.")
        st.stop()

    # All candidate target nodes of the requested type
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        st.warning(f"No {target_node_type} nodes were found in the knowledge graph.")
        st.stop()
    n_targets = target_nodes.shape[0]
    dst_indices = target_nodes.node_index.values

    # score = sigmoid(sum(leaky_relu(src) * rel_weights * leaky_relu(dst))).
    # The single source embedding broadcasts against every target row, so it
    # is not copied once per target. Inference needs no gradients.
    with torch.no_grad():
        src_embedding = F.leaky_relu(embeddings[int(src_index)])
        dst_embeddings = F.leaky_relu(embeddings[dst_indices])
        rel_weights = relation_weights[edge_type_index]
        scores = torch.sigmoid(torch.sum(src_embedding * rel_weights * dst_embeddings, dim = 1))

    # Rank targets by descending score (rank 1 = highest score)
    target_nodes['score'] = scores.cpu().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, n_targets + 1)

    # Select and rename the columns shown to the user
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})

    # Map each node type to a function that turns a node ID into an external database URL
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
        'disease': lambda x: x, # MONDO
        # GO terms: pad with 0s to 7 digits
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }

    # Database name (from the top-ranked row) is used as the link label;
    # then the column is replaced with per-row URLs
    display_database = display_data['Database'].values[0]
    display_data['Database'] = display_data['ID'].map(map_dbs[target_node_type])

    # Optional ground-truth validation, only offered when validation data exists
    show_val = False
    if 'validation' in st.session_state:
        show_val = st.checkbox("Show Ground Truth Validation?", value = False)

    # results_data is always a plain DataFrame; styling is applied only when rendering,
    # because a pandas Styler supports neither .copy(), column assignment, nor .iterrows()
    results_data = display_data
    val_relations = []
    if show_val:
        val_results = st.session_state.validation.copy()

        # Left-join validation labels onto predictions; unmatched rows get 0
        val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how = 'left')
        val_display_data = val_display_data.fillna(0).drop(columns = 'y_id')

        # Columns contributed by the validation table: show 1 as a check mark, 0 as blank
        val_relations = val_display_data.columns.difference(display_data.columns).tolist()
        for col in val_relations:
            val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
        results_data = val_display_data


    # NODE SEARCH

    # Use multiselect to search for specific nodes
    selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
                                    display_data.Name, placeholder = "Type to search...")

    if len(selected_nodes) > 0:

        # Plain DataFrame with numeric Rank/Score; reused below for plotting
        selected_display_data = results_data[results_data.Name.isin(selected_nodes)].reset_index(drop = True)

        st.markdown(f"Out of {n_targets} {target_node_type} nodes, the selected nodes rank as follows:")
        selected_display_data_with_rank = selected_display_data.copy()
        selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/n_targets):.2f}% of predictions)")
        if show_val:
            selected_display_data_with_rank = selected_display_data_with_rank.style.map(_style_val, subset = val_relations)

        # Show filtered nodes
        _show_results_table(selected_display_data_with_rank, target_node_type, display_database)

        # Show plot
        st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")

        # Checkbox to show text labels
        show_labels = st.checkbox("Show Text Labels?", value = False)

        fig = _plot_selected_ranks(display_data, selected_display_data, show_labels)
        st.pyplot(fig)
        # Release the figure; otherwise every rerun leaks one into pyplot's registry
        plt.close(fig)


    # FULL RESULTS

    # Show top ranked nodes
    st.subheader("Model Predictions", divider = "blue")
    top_k = st.slider('Select number of top ranked nodes to show.', 1, n_targets, min(500, n_targets))

    full_results = results_data.iloc[:top_k]
    if show_val:
        full_results = full_results.style.map(_style_val, subset = val_relations)
    _show_results_table(full_results, target_node_type, display_database)

    # Save to session state
    st.session_state.predictions = display_data
    st.session_state.display_database = display_database