# Standard library imports
import datetime
import base64
import os

# Related third-party imports
import streamlit as st
from streamlit_elements import elements
from google_auth_oauthlib.flow import Flow
from googleapiclient.discovery import build
from dotenv import load_dotenv
import numpy as np
import pandas as pd
import searchconsole
import cohere
from sklearn.metrics.pairwise import cosine_similarity
import requests
from bs4 import BeautifulSoup

# Pull variables from a local .env file (when one exists) into os.environ
# before anything below reads them.
load_dotenv()

# Cohere client shared by generate_embeddings(). A missing COHERE_API_KEY
# raises KeyError as soon as this module is imported.
COHERE_API_KEY = os.environ["COHERE_API_KEY"]
co = cohere.Client(api_key=COHERE_API_KEY)

# Deployment flag: True when running locally, False on a hosted deployment.
# NOTE(review): nothing visible in this file reads IS_LOCAL — confirm before relying on it.
IS_LOCAL = False

# ---- UI / query constants ----
# Search types offered in the search-type dropdown.
SEARCH_TYPES = [
    "web",
    "image",
    "video",
    "news",
    "discover",
    "googleNews",
]
# Preset windows for the date-range dropdown; calc_date_range() maps each
# preset (except "Custom Range") to a number of days.
DATE_RANGE_OPTIONS = [
    "Last 7 Days", "Last 30 Days", "Last 3 Months", "Last 6 Months",
    "Last 12 Months", "Last 16 Months", "Custom Range",
]
# Device choices. NOTE(review): no visible widget renders these yet.
DEVICE_OPTIONS = ["All Devices", "desktop", "mobile", "tablet"]
# Dimensions always available; "device" is appended by update_dimensions().
BASE_DIMENSIONS = ["page", "query", "country", "date"]
# Row cap passed to the Search Console query.
MAX_ROWS = 250_000
# Rows shown by show_dataframe()'s preview expander.
DF_PREVIEW_ROWS = 100

# -------------
# Streamlit App Configuration
# -------------

def setup_streamlit():
    """Configure the Streamlit page and render the title, subtitle and credits.

    ``st.set_page_config`` must be the first Streamlit call of a run, so this is
    called at the very start of ``main()``.
    """
    st.set_page_config(page_title="✨ Simple Google Search Console Data | LeeFoot.co.uk", layout="wide")
    st.title("✨ Simple Google Search Console Data | June 2024")
    st.markdown(f"### Lightweight GSC Data Extractor. (Max {MAX_ROWS:,} Rows)")
    # Raw HTML so the links can open in a new tab; the <p> is now properly closed.
    st.markdown(
        """
        <p>
            Created by <a href="https://twitter.com/LeeFootSEO" target="_blank">LeeFootSEO</a> |
            <a href="https://leefoot.co.uk" target="_blank">More Apps & Scripts on my Website</a>
        </p>
        """,
        unsafe_allow_html=True
    )
    st.divider()
    st.divider()

def init_session_state():
    """Seed ``st.session_state`` with default UI selections.

    Only keys that are not already present are written, so values chosen
    earlier in the session survive Streamlit reruns.
    """
    today = datetime.date.today()
    week_ago = today - datetime.timedelta(days=7)
    defaults = {
        'selected_property': None,
        'selected_search_type': 'web',
        'selected_date_range': 'Last 7 Days',
        'start_date': week_ago,
        'end_date': today,
        'selected_dimensions': ['page', 'query'],
        'selected_device': 'All Devices',
        'custom_start_date': week_ago,
        'custom_end_date': today,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

# -------------
# Data Processing Functions
# -------------

def fetch_content(url, timeout=10):
    """Download *url* and return its visible text.

    Args:
        url: Page URL to fetch.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout, ``requests`` can block indefinitely
            on an unresponsive host and stall the whole scoring run.

    Returns:
        The page's text content (tags stripped, whitespace-joined). If the
        request fails (network error, timeout, non-2xx status), the exception
        message is returned instead.

    NOTE(review): the error-message fallback is embedded like real page text by
    the caller, so a failed fetch still gets a (meaningless) relevancy score.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        return str(e)
    soup = BeautifulSoup(response.text, 'html.parser')
    return soup.get_text(separator=' ', strip=True)

def generate_embeddings(text_list, input_type='search_document', batch_size=96):
    """Embed *text_list* with Cohere's ``embed-english-v3.0`` model.

    Args:
        text_list: Iterable of strings to embed.
        input_type: Cohere input type. Use ``'search_document'`` for documents
            and ``'search_query'`` for queries that will be compared to them.
        batch_size: Maximum number of texts per request. The embed endpoint
            caps texts per call (96 at time of writing; confirm against the
            Cohere docs), so larger inputs are sent in several requests.

    Returns:
        A list of embedding vectors, one per input text, in input order.
        An empty input returns ``[]`` without calling the API.
    """
    texts = list(text_list)
    if not texts:
        return []
    model = 'embed-english-v3.0'
    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = co.embed(
            model=model,
            texts=texts[start:start + batch_size],
            input_type=input_type,
        )
        embeddings.extend(response.embeddings)
    return embeddings


def _paired_cosine_similarity(a, b):
    """Return the row-wise cosine similarity of two equal-shaped vector lists.

    Only the n matching pairs are computed, instead of the n x n matrix that
    ``cosine_similarity(a, b).diagonal()`` would build. A row containing a
    zero-length vector scores 0.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"embedding shapes differ: {a.shape} vs {b.shape}")
    dots = (a * b).sum(axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)


def calculate_relevancy_scores(df):
    """Return *df* with a ``relevancy_score`` column added.

    Each row's score is the cosine similarity between the embedding of its
    ``query`` and the embedding of the text fetched from its ``page`` URL.
    Every page is fetched over HTTP, one at a time.

    On any failure (network, embedding API, shape mismatch), a Streamlit
    warning is shown and every score is set to 0.
    """
    try:
        page_contents = [fetch_content(url) for url in df['page']]
        page_embeddings = generate_embeddings(page_contents, input_type='search_document')
        # v3 embed models are asymmetric: queries use the query-side input type.
        query_embeddings = generate_embeddings(df['query'].tolist(), input_type='search_query')
        relevancy_scores = _paired_cosine_similarity(query_embeddings, page_embeddings)
        df = df.assign(relevancy_score=relevancy_scores)
    except Exception as e:
        st.warning(f"Error calculating relevancy scores: {e}")
        df = df.assign(relevancy_score=0)
    return df

def process_gsc_data(df):
    """Reduce Search Console rows to one row per page: its top-clicked query.

    Args:
        df: DataFrame with at least ``page``, ``query``, ``clicks``,
            ``impressions``, ``ctr`` and ``position`` columns, and optionally
            ``relevancy_score``.

    Returns:
        DataFrame with columns ``page, query, clicks, impressions, ctr,
        position, relevancy_score``, sorted by page. For each page, the row
        with the most clicks is kept, together with that row's own
        ``relevancy_score`` (0 when the input has no such column). An empty
        input, including a column-less ``pd.DataFrame()``, yields an empty
        frame with those columns.
    """
    columns = ['page', 'query', 'clicks', 'impressions', 'ctr', 'position', 'relevancy_score']
    if df.empty:
        return pd.DataFrame(columns=columns)
    df_sorted = df.sort_values(['page', 'clicks'], ascending=[True, False])
    # The first row per page after the sort is that page's highest-click row;
    # it already carries its own relevancy_score when one exists.
    df_unique = df_sorted.drop_duplicates(subset='page', keep='first').copy()
    if 'relevancy_score' not in df_unique.columns:
        df_unique['relevancy_score'] = 0
    return df_unique[columns]

# -------------
# Google Authentication Functions
# -------------

def load_config():
    """Build the OAuth client configuration for a Google "web" application.

    The client ID and secret come from the ``CLIENT_ID`` / ``CLIENT_SECRET``
    environment variables (``KeyError`` if either is unset). The redirect URI
    is a fixed deployment URL.
    """
    web_client = {
        "client_id": os.environ["CLIENT_ID"],
        "client_secret": os.environ["CLIENT_SECRET"],
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "redirect_uris": ["https://poemsforaphrodite-gscpro.hf.space/"],
    }
    return {"web": web_client}

def init_oauth_flow(client_config):
    """Create an OAuth ``Flow`` for read-only Search Console access.

    The first entry of ``client_config["web"]["redirect_uris"]`` is used as the
    redirect target.
    """
    redirect_target = client_config["web"]["redirect_uris"][0]
    return Flow.from_client_config(
        client_config,
        scopes=["https://www.googleapis.com/auth/webmasters.readonly"],
        redirect_uri=redirect_target,
    )

def google_auth(client_config):
    """Start the Google sign-in flow.

    Returns:
        ``(flow, auth_url)``: the OAuth flow (kept so the returned code can be
        exchanged later) and the URL the user must visit to grant access.

    ``access_type="offline"`` asks Google to issue a refresh token, so the
    credentials can be refreshed after the short-lived access token expires.
    ``prompt="consent"`` forces the consent screen, so a refresh token is
    issued again on repeat sign-ins.
    """
    flow = init_oauth_flow(client_config)
    auth_url, _ = flow.authorization_url(prompt="consent", access_type="offline")
    return flow, auth_url

def auth_search_console(client_config, credentials):
    """Return a ``searchconsole`` account authenticated with *credentials*.

    The OAuth credentials object is flattened into the plain dict that
    ``searchconsole.authenticate`` accepts. ``id_token`` is optional and
    defaults to ``None`` when the credentials object lacks it.
    """
    required_fields = ("token", "refresh_token", "token_uri", "client_id", "client_secret", "scopes")
    token = {name: getattr(credentials, name) for name in required_fields}
    token["id_token"] = getattr(credentials, "id_token", None)
    return searchconsole.authenticate(client_config=client_config, credentials=token)

# -------------
# Data Fetching Functions
# -------------

def list_gsc_properties(credentials):
    """Return the site URLs of every Search Console property the user can see.

    When the account has no properties, a single placeholder entry
    ``"No properties found"`` is returned instead of an empty list.
    """
    service = build('webmasters', 'v3', credentials=credentials)
    response = service.sites().list().execute()
    site_urls = [entry['siteUrl'] for entry in response.get('siteEntry', [])]
    if not site_urls:
        return ["No properties found"]
    return site_urls

def fetch_gsc_data(webproperty, search_type, start_date, end_date, dimensions, device_type=None):
    """Query Search Console and return one row per page (see ``process_gsc_data``).

    Args:
        webproperty: ``searchconsole`` web property to query.
        search_type: One of ``SEARCH_TYPES``.
        start_date, end_date: Inclusive date range.
        dimensions: Dimensions to group by.
        device_type: Optional device filter ("desktop", "mobile", "tablet").
            Applied only when "device" is among *dimensions*; ``None`` and
            "All Devices" mean no filter.

    Returns:
        The processed DataFrame. On any error, the message is shown in the UI
        and an empty ``pd.DataFrame()`` is returned.
    """
    try:
        # Built inside the try so errors raised while assembling the query are
        # reported the same way as errors from running it.
        query = (
            webproperty.query
            .range(start_date, end_date)
            .search_type(search_type)
            .dimension(*dimensions)
        )
        if 'device' in dimensions and device_type and device_type != 'All Devices':
            # searchconsole's Query.filter takes (dimension, expression, operator).
            # The API documents device values in upper case (DESKTOP/MOBILE/TABLET).
            query = query.filter('device', device_type.upper(), 'equals')
        df = query.limit(MAX_ROWS).get().to_dataframe()
        return process_gsc_data(df)
    except Exception as e:
        show_error(e)
        return pd.DataFrame()

def fetch_data_loading(webproperty, search_type, start_date, end_date, dimensions, device_type=None):
    """Fetch GSC data and add relevancy scores, showing a spinner meanwhile.

    Returns:
        One row per page with a ``relevancy_score`` column. If the fetch fails
        or returns nothing, the empty DataFrame from ``fetch_gsc_data`` is
        returned as-is.
    """
    with st.spinner('Fetching data and calculating relevancy scores...'):
        df = fetch_gsc_data(webproperty, search_type, start_date, end_date, dimensions, device_type)
        if df.empty:
            # On error, fetch_gsc_data returns a column-less frame; processing
            # it again would raise KeyError('page').
            return df
        df = calculate_relevancy_scores(df)
        return process_gsc_data(df)

# -------------
# Utility Functions
# -------------

def update_dimensions(selected_search_type):
    """Return the dimensions selectable for *selected_search_type*.

    Recognised search types (those in ``SEARCH_TYPES``) get "device" appended
    to ``BASE_DIMENSIONS``. Anything else gets ``BASE_DIMENSIONS`` itself.
    """
    if selected_search_type not in SEARCH_TYPES:
        return BASE_DIMENSIONS
    return BASE_DIMENSIONS + ['device']

def calc_date_range(selection, custom_start=None, custom_end=None):
    """Translate a date-range choice into a ``(start_date, end_date)`` tuple.

    Presets count back a fixed number of days from today and end today.
    "Custom Range" returns the supplied dates when both are given, otherwise
    the last 7 days. Unknown selections collapse to ``(today, today)``.
    """
    today = datetime.date.today()

    if selection == 'Custom Range':
        if custom_start and custom_end:
            return custom_start, custom_end
        return today - datetime.timedelta(days=7), today

    days_back = {
        'Last 7 Days': 7,
        'Last 30 Days': 30,
        'Last 3 Months': 90,
        'Last 6 Months': 180,
        'Last 12 Months': 365,
        'Last 16 Months': 480,
    }.get(selection, 0)
    return today - datetime.timedelta(days=days_back), today

def show_error(e):
    """Display *e* in a Streamlit error box."""
    message = f"An error occurred: {e}"
    st.error(message)

def property_change():
    """Selectbox callback: remember the newly chosen property in session state."""
    chosen = st.session_state['selected_property_selector']
    st.session_state.selected_property = chosen

# -------------
# File & Download Operations
# -------------

def show_dataframe(report):
    """Show the first ``DF_PREVIEW_ROWS`` rows of *report* inside a collapsible expander."""
    preview = report.head(DF_PREVIEW_ROWS)
    with st.expander("Preview the First 100 Rows (Unique Pages with Top Query)"):
        st.dataframe(preview)

def download_csv_link(report):
    """Render an HTML link that downloads *report* as ``search_console_data.csv``.

    The CSV is embedded in a base64 ``data:`` URI. It is encoded as UTF-8 with a
    byte-order mark so spreadsheet apps such as Excel detect the encoding and
    display non-ASCII queries correctly.
    """
    # pandas does not reliably apply `encoding` when to_csv() returns a str
    # (no path given), so the BOM is added here when converting to bytes.
    csv_bytes = report.to_csv(index=False).encode('utf-8-sig')
    b64_csv = base64.b64encode(csv_bytes).decode()
    # "text/csv" is the registered MIME type; the previous "file/csv" is not one.
    href = f'<a href="data:text/csv;base64,{b64_csv}" download="search_console_data.csv">Download CSV File</a>'
    st.markdown(href, unsafe_allow_html=True)

# -------------
# Streamlit UI Components
# -------------

def show_google_sign_in(auth_url):
    """Show a sidebar sign-in button. Once it is clicked, show the Google authorization link."""
    sidebar = st.sidebar
    with sidebar:
        clicked = st.button("Sign in with Google")
        if not clicked:
            return
        st.write('Please click the link below to sign in:')
        st.markdown(f'[Google Sign-In]({auth_url})', unsafe_allow_html=True)

def show_property_selector(properties, account):
    """Render the property dropdown and return the chosen ``searchconsole`` web property.

    The dropdown pre-selects the property remembered in session state, or the
    first entry when none is remembered. Changes are persisted via
    ``property_change``.
    """
    remembered = st.session_state.selected_property
    default_index = properties.index(remembered) if remembered in properties else 0
    choice = st.selectbox(
        "Select a Search Console Property:",
        properties,
        index=default_index,
        key='selected_property_selector',
        on_change=property_change,
    )
    return account[choice]

def show_search_type_selector():
    """Render the search-type dropdown, pre-selecting the type held in session state."""
    current_index = SEARCH_TYPES.index(st.session_state.selected_search_type)
    selection = st.selectbox(
        "Select Search Type:",
        SEARCH_TYPES,
        index=current_index,
        key='search_type_selector',
    )
    return selection

def show_date_range_selector():
    """Render the date-range dropdown, pre-selecting the range held in session state."""
    current_index = DATE_RANGE_OPTIONS.index(st.session_state.selected_date_range)
    selection = st.selectbox(
        "Select Date Range:",
        DATE_RANGE_OPTIONS,
        index=current_index,
        key='date_range_selector',
    )
    return selection

def show_custom_date_inputs():
    """Render start/end date pickers and store the picked dates back into session state."""
    state = st.session_state
    state.custom_start_date = st.date_input("Start Date", state.custom_start_date)
    state.custom_end_date = st.date_input("End Date", state.custom_end_date)

def show_dimensions_selector(search_type):
    """Render a multiselect of the dimensions valid for *search_type*.

    The multiselect defaults to the dimensions stored in session state.
    """
    options = update_dimensions(search_type)
    chosen = st.multiselect(
        "Select Dimensions:",
        options,
        default=st.session_state.selected_dimensions,
        key='dimensions_selector',
    )
    return chosen

def show_paginated_dataframe(report, rows_per_page=20):
    """Render *report* one page at a time with Previous/Next buttons.

    The current page is kept in ``st.session_state.current_page``. Page changes
    happen in ``on_click`` callbacks, which Streamlit runs before the next
    script rerun. That way the page label, the buttons' disabled state and the
    rows shown all reflect the same page. (Acting on ``st.button``'s return
    value inline would change the page only after the label and the other
    button had already been drawn.)

    Args:
        report: DataFrame to display.
        rows_per_page: Number of rows per page (default 20).
    """
    total_rows = len(report)
    # Ceiling division, with at least one page so an empty frame reads "Page 1 of 1".
    total_pages = max(1, -(-total_rows // rows_per_page))

    # Clamp a stale page number, e.g. one left over from a larger earlier report.
    current = st.session_state.get('current_page', 1)
    st.session_state.current_page = min(max(current, 1), total_pages)

    def _previous_page():
        st.session_state.current_page = max(1, st.session_state.current_page - 1)

    def _next_page():
        st.session_state.current_page = min(total_pages, st.session_state.current_page + 1)

    col1, col2, col3 = st.columns([1, 3, 1])
    with col1:
        st.button("Previous", disabled=st.session_state.current_page <= 1, on_click=_previous_page)
    with col2:
        st.write(f"Page {st.session_state.current_page} of {total_pages}")
    with col3:
        st.button("Next", disabled=st.session_state.current_page >= total_pages, on_click=_next_page)

    start_idx = (st.session_state.current_page - 1) * rows_per_page
    st.dataframe(report.iloc[start_idx:start_idx + rows_per_page])

# -------------
# Main Streamlit App Function
# -------------

def main():
    """App entry point, executed on every Streamlit rerun.

    Steps:
      1. Set up the page chrome and load the OAuth client config.
      2. Create the OAuth flow and its authorization URL (once per session).
      3. If Google redirected back with ``?code=...``, exchange the code for
         credentials.
      4. If not signed in, show the sign-in link. If signed in, show the
         selectors, fetch data on demand, and render the paginated report
         plus the CSV download link.
    """
    setup_streamlit()
    client_config = load_config()

    if 'auth_flow' not in st.session_state or 'auth_url' not in st.session_state:
        st.session_state.auth_flow, st.session_state.auth_url = google_auth(client_config)

    query_params = st.experimental_get_query_params()
    auth_code = query_params.get("code", [None])[0]

    if auth_code and 'credentials' not in st.session_state:
        try:
            st.session_state.auth_flow.fetch_token(code=auth_code)
            st.session_state.credentials = st.session_state.auth_flow.credentials
        except Exception as e:
            # e.g. a code that was already used or has expired (page reload).
            show_error(e)
        finally:
            # Authorization codes are single-use: remove ?code=... from the URL
            # so a reload does not try to exchange it again.
            st.experimental_set_query_params()

    if 'credentials' not in st.session_state:
        show_google_sign_in(st.session_state.auth_url)
        return

    init_session_state()
    account = auth_search_console(client_config, st.session_state.credentials)
    properties = list_gsc_properties(st.session_state.credentials)

    # list_gsc_properties() reports "no properties" with a placeholder entry.
    # The placeholder is not a real property and cannot be looked up in `account`.
    if not properties or properties == ["No properties found"]:
        st.warning("No Search Console properties found for this account.")
        return

    webproperty = show_property_selector(properties, account)
    search_type = show_search_type_selector()
    date_range_selection = show_date_range_selector()

    if date_range_selection == 'Custom Range':
        show_custom_date_inputs()
        start_date, end_date = st.session_state.custom_start_date, st.session_state.custom_end_date
    else:
        start_date, end_date = calc_date_range(date_range_selection)

    selected_dimensions = show_dimensions_selector(search_type)

    if 'report_data' not in st.session_state:
        st.session_state.report_data = None

    if st.button("Fetch Data"):
        # fetch_data_loading() displays its own spinner.
        st.session_state.report_data = fetch_data_loading(
            webproperty, search_type, start_date, end_date, selected_dimensions
        )
        # A new result set starts from the first page.
        st.session_state.current_page = 1

    report = st.session_state.report_data
    if report is not None and not report.empty:
        show_paginated_dataframe(report)
        download_csv_link(report)
    elif report is not None:
        st.warning("No data found for the selected criteria.")

                
# Run the app when this file is executed as the main script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()