import gradio as gr
import pandas as pd
import numpy as np
import json

# ---------------------- Accessing data from Notion ---------------------- #


from notion_client import Client as client_notion
from config import (
    landuseDatabaseId,
    subdomainAttributesDatabaseId,
    landuseColumnName,
    subdomainColumnName,
    sqmPerEmployeeColumnName,
    thresholdsColumnName,
    maxPointsColumnName,
    domainColumnName,
)
from imports_utils import (
    fetch_all_database_pages,
    get_property_value,
    notion,
    notionToken,
    fetchDomainMapper,
    fetchSubdomainMapper,
)

if notionToken is None:
    raise Exception("Notion token not found. Please check the environment variables.")
else:
    print("Notion token found successfully!")
    landuse_attributes = fetch_all_database_pages(notion, landuseDatabaseId)
    livability_attributes = fetch_all_database_pages(notion, subdomainAttributesDatabaseId)
    landuseMapperDict = fetchDomainMapper(landuse_attributes)
    livabilityMapperDict = fetchSubdomainMapper(livability_attributes)
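    # A quick size report helps catch an empty or misconfigured database early;
    # both mappers are plain dicts produced by imports_utils.
    print(f"Fetched {len(landuseMapperDict)} land-use and {len(livabilityMapperDict)} livability mappings.")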




# ---------------------- Accessing data from Speckle ---------------------- #


from specklepy.api.client import SpeckleClient

import imports_utils
import speckle_utils
import data_utils

from config import streamId, dmBranchName, dmCommitId, luBranchName, luCommitId
from config import distanceMatrixActivityNodes, distanceMatrixTransportStops
from imports_utils import speckleToken, fetchDistanceMatrices


if speckleToken is None:
    raise Exception("Speckle token not found. Please check the environment variables.")
else:
    print("Speckle token found successfully!")

    CLIENT = SpeckleClient(host="https://speckle.xyz/")
    CLIENT.authenticate_with_token(token=speckleToken)
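    # Optional sanity check (commented out to avoid an extra round trip): the
    # authenticated user should resolve; None or an exception means the token
    # is invalid or lacks scopes.
    # user = CLIENT.active_user.get()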
    
    streamDistanceMatrices = speckle_utils.getSpeckleStream(streamId, dmBranchName, CLIENT, dmCommitId)
    matrices = fetchDistanceMatrices(streamDistanceMatrices)
    streamLanduses = speckle_utils.getSpeckleStream(streamId, luBranchName, CLIENT, luCommitId)
    # "@Data" holds the Grasshopper-sent payload; "@{0}" addresses its first branch
    streamData = streamLanduses["@Data"]["@{0}"]

    df_speckle_lu = speckle_utils.get_dataframe(streamData, return_original_df=False)
    df_lu = df_speckle_lu.copy()
    df_lu = df_lu.astype(str)
    df_lu = df_lu.set_index("ids", drop=False)

    df_dm = matrices[distanceMatrixActivityNodes]
    df_dm = df_dm.apply(pd.to_numeric, errors='coerce')
    # Replace infinities with 10000 and NaNs with 0, then convert to integers
    df_dm = df_dm.replace([np.inf, -np.inf], 10000).fillna(0)
    df_dm = df_dm.round(0).astype(int)
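    # 10000 is effectively an "unreachable" sentinel: with the distance-decay
    # weighting applied downstream, contributions from such pairs are negligible.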
    
    #df_dm_transport = matrices[distanceMatrixTransportStops]
    dm_dictionary = df_dm.to_dict('index')
    #df_dm_transport_dictionary = df_dm_transport.to_dict('index')
    
    # Filter activity-node attributes: keep only nodes present in the distance
    # matrix and only the "lu+"-prefixed land-use columns
    mask_connected = df_dm.index.tolist()
    lu_columns = [name for name in df_lu.columns if name.startswith("lu+")]

    df_lu_filtered = df_lu[lu_columns].loc[mask_connected]
    df_lu_filtered.columns = [col.replace('lu+', '') for col in df_lu_filtered.columns]
    df_lu_filtered.columns = [col.replace('ASSETS+', '') for col in df_lu_filtered.columns]

    df_lu_filtered = df_lu_filtered.apply(pd.to_numeric, errors='coerce').fillna(0)
    df_lu_filtered = df_lu_filtered.astype(int)
    # Merge columns that share a name after prefix stripping by summing them
    df_lu_filtered = df_lu_filtered.T.groupby(level=0).sum().T

    df_lu_filtered_dict = df_lu_filtered.to_dict('index')
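    # Illustrative shape of df_lu_filtered_dict (actual keys depend on the
    # stream's "lu+" columns): {"node_12": {"residential": 1200, "retail": 300}, ...}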




def test(input_json):
    print("Received input")
    # Parse the input JSON string; fall back to swapping single quotes for
    # double quotes to accept Python-style dict strings
    try:
        inputs = json.loads(input_json)
    except json.JSONDecodeError:
        inputs = json.loads(input_json.replace("'", '"'))
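    # Expected payload shape (illustrative; keys follow the accesses below):
    # {"input": {"matrix": {...}, "landuse_areas": {...},
    #            "attributeMapperDict": {...}, "landuseMapperDict": {...},
    #            "alpha": "<float or empty>", "threshold": "<float or empty>"}}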

    
    # ------------------------- Accessing input data from Grasshopper ------------------------- #

    from config import useGrasshopperData

    if useGrasshopperData:
        matrix = inputs['input']["matrix"]
        landuses = inputs['input']["landuse_areas"]

        dfLanduses = pd.DataFrame(landuses).T
        dfLanduses = dfLanduses.apply(pd.to_numeric, errors='coerce')
        dfLanduses = dfLanduses.round(0).astype(int)

        dfMatrix = pd.DataFrame(matrix).T
        dfMatrix = dfMatrix.apply(pd.to_numeric, errors='coerce')
        dfMatrix = dfMatrix.round(0).astype(int)
    else:
        # Fall back to the data fetched from Speckle at startup
        dfLanduses = df_lu_filtered.copy()
        dfMatrix = df_dm.copy()
    

    attributeMapperDict_gh = inputs['input']["attributeMapperDict"]
    landuseMapperDict_gh = inputs['input']["landuseMapperDict"]

    # Use the defaults from imports_utils unless the request overrides them;
    # overrides arrive as strings in the JSON payload, hence the float() casts
    if not inputs['input']["alpha"]:
        from imports_utils import alpha
    else:
        alpha = float(inputs['input']["alpha"])

    if not inputs['input']["threshold"]:
        from imports_utils import threshold
    else:
        threshold = float(inputs['input']["threshold"])


    from imports_utils import splitDictByStrFragmentInColumnName

    """
    # List containing the substrings to check against
    tranportModes = ["DRT", "GMT", "HSR"]

    result_dicts = splitDictByStrFragmentInColumnName(df_dm_transport_dictionary, tranportModes)

    # Accessing each dictionary
    art_dict = result_dicts["DRT"]
    gmt_dict = result_dicts["GMT"]

    df_art_matrix = pd.DataFrame(art_dict).T
    df_art_matrix = df_art_matrix.round(0).astype(int)  
    df_gmt_matrix = pd.DataFrame(gmt_dict).T
    df_gmt_matrix = df_art_matrix.round(0).astype(int)     

    """

    # Create a mask from the matrix ids and crop the activity nodes to it
    mask_connected = dfMatrix.index.tolist()

    valid_indexes = [idx for idx in mask_connected if idx in dfLanduses.index]
    # Identify and report missing indexes
    missing_indexes = set(mask_connected) - set(valid_indexes)
    if missing_indexes:
        print(f"Error: {len(missing_indexes)} matrix indexes were not found in the land-use DataFrame: {missing_indexes}")

    # Apply the filtered mask
    dfLanduses_filtered = dfLanduses.loc[valid_indexes]
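    # dfLanduses_filtered is now index-aligned with the distance matrix rows
    # (minus any reported missing ids), which the aggregation steps below rely on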


    from imports_utils import (
        findUniqueDomains,
        findUniqueSubdomains,
        landusesToSubdomains,
        FindWorkplacesNumber,
        computeAccessibility,
        computeAccessibility_pointOfInterest,
        remap,
        accessibilityToLivability,
    )

    
    domainsUnique = findUniqueDomains(livabilityMapperDict)
    subdomainsUnique = findUniqueSubdomains(landuseMapperDict)

    # Use the masked landuses computed above so Grasshopper inputs are respected
    LivabilitySubdomainsWeights = landusesToSubdomains(dfMatrix, dfLanduses_filtered, landuseMapperDict, subdomainsUnique)

    WorkplacesNumber = FindWorkplacesNumber(dfMatrix, livabilityMapperDict, LivabilitySubdomainsWeights, subdomainsUnique)

    # Prepare an input weights dataframe for the parameter LivabilitySubdomainsInputs
    LivabilitySubdomainsInputs = pd.concat([LivabilitySubdomainsWeights, WorkplacesNumber], axis=1)

    subdomainsAccessibility = computeAccessibility(dfMatrix, LivabilitySubdomainsInputs, alpha, threshold)
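    # computeAccessibility presumably applies a distance-decay weighting, on the
    # order of A_i = sum_j w_j * exp(-alpha * d_ij) for d_ij <= threshold;
    # see imports_utils for the exact formulation.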
    #artAccessibility = computeAccessibility_pointOfInterest(df_art_matrix,'ART',alpha,threshold)
    #gmtAccessibility = computeAccessibility_pointOfInterest(df_gmt_matrix,'GMT+HSR',alpha,threshold)
    
    #AccessibilityInputs = pd.concat([subdomainsAccessibility, artAccessibility,gmtAccessibility], axis=1)
        

    if 'jobs' not in subdomainsAccessibility.columns:
        print("Error: Column 'jobs' does not exist in subdomainsAccessibility.")

    livability = accessibilityToLivability(dfMatrix, subdomainsAccessibility, livabilityMapperDict, domainsUnique)

    livability_dictionary = livability.to_dict('index')
    LivabilitySubdomainsInputs_dictionary = LivabilitySubdomainsInputs.to_dict('index')
    subdomainsAccessibility_dictionary = subdomainsAccessibility.to_dict('index')
    LivabilitySubdomainsWeights_dictionary = LivabilitySubdomainsWeights.to_dict('index')
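    # to_dict('index') yields {row_id: {column: value}} mappings, so the JSON
    # output stays keyed by the same node ids as the input matrices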
    
    
    # Prepare the output
    output = {
        "subdomainsAccessibility_dictionary": subdomainsAccessibility_dictionary,
        "livability_dictionary": livability_dictionary,
        "subdomainsWeights_dictionary": LivabilitySubdomainsInputs_dictionary,
        "luDomainMapper": landuseMapperDict,
        "attributeMapper": livabilityMapperDict,
        "fetchDm": dm_dictionary,
        "landuses": df_lu_filtered_dict
    }

    return json.dumps(output)
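# A minimal local smoke test (hypothetical payload, illustration only; empty
# alpha/threshold fall back to the imports_utils defaults):
#   sample = {"input": {"matrix": {}, "landuse_areas": {},
#                       "attributeMapperDict": {}, "landuseMapperDict": {},
#                       "alpha": "", "threshold": ""}}
#   print(test(json.dumps(sample)))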

# Define the Gradio interface with a single JSON input
iface = gr.Interface(
    fn=test,
    inputs=gr.Textbox(label="Input JSON", lines=20, placeholder="Enter JSON with all parameters here..."),
    outputs=gr.JSON(label="Output JSON"),
    title="testspace"
)

iface.launch()