import os
import re
import cv2
import time
import bcrypt
import numpy as np
import streamlit as st
import tensorflow as tf
from pymongo import MongoClient

@st.cache_resource(show_spinner=False)
def load_data():
    """Connect to MongoDB and load the two Keras classifiers.

    Streamlit re-executes this module on every interaction, so the function is
    wrapped in ``st.cache_resource``. The client and models are then created
    once and reused across reruns, instead of being reloaded from disk each
    time. The spinner is disabled so no UI element is emitted before the
    pages call ``st.set_page_config``.

    Returns:
        Tuple ``(collection, brain_tumor_model, alzheimer_model)``.
        ``collection`` is the ``login-data`` collection of the
        ``medical-app-auth`` database.
    """
    # NOTE(review): if MONGODB_URL is unset, None is passed and MongoClient
    # falls back to its default host. Confirm that is intended.
    client = MongoClient(os.environ.get("MONGODB_URL"))
    collection = client['medical-app-auth']['login-data']
    # NOTE(review): the cached models are shared by every session of the app.
    brain_tumor_model = tf.keras.models.load_model('models/brain_tumor.h5')
    alzheimer_model = tf.keras.models.load_model('models/alzheimer.h5')
    return collection, brain_tumor_model, alzheimer_model

# Module-wide handles used by the login/register/medical pages.
collection, brain_tumor_model, alzheimer_model = load_data()

# Fresh session (or one emptied by clear_cache): nobody is signed in and the
# login page is shown.
if 'current_page' not in st.session_state:
    st.session_state['current_user'] = None
    st.session_state['current_page'] = 'login'
    
def clear_cache():
    """Delete every entry from ``st.session_state`` (used as the Logout callback).

    The keys are snapshotted into a list first so the mapping is not mutated
    while it is being iterated.
    """
    for state_key in list(st.session_state):
        del st.session_state[state_key]

def login():
    """Render the login page and authenticate the submitted credentials.

    The username is lower-cased and stripped, then looked up in
    ``collection``. The password is verified against the stored bcrypt hash.
    On success the session switches to the 'medical' page for that user and
    the script reruns; otherwise a warning is shown. The top-right "Register"
    button switches to the registration page.
    """
    st.set_page_config(layout='centered', page_title="Brain MRI", page_icon="mri_of_brain.jpg")
    _, register_col = st.columns([5, 1])
    with register_col:
        if st.button("Register", use_container_width=True):
            st.session_state.current_page = 'register'
            st.rerun()

    # The text inputs have no `key`, so writing st.session_state cannot reset
    # them. clear_on_submit=True already empties them after every submission.
    with st.form(key='login', clear_on_submit=True):
        st.subheader("Login")

        username = st.text_input("Username", placeholder="Enter Username")
        password = st.text_input("Enter Password", type="password")
        username = username.lower().strip()
        submit = st.form_submit_button("Login")
        if not submit:
            return

        with st.spinner('Checking credentials...'):
            if not username or not password:
                st.warning("Please enter both username and password")
                return

            user = collection.find_one({"username": username})
            # NOTE(review): separate "unknown user" / "wrong password" messages
            # reveal which usernames exist. Consider one generic message.
            if user is None:
                st.warning("Username does not exist, please register")
                return

            # register() stores the bcrypt.hashpw output under "password".
            # A record without it counts as a failed login instead of crashing
            # bcrypt.checkpw with a None hash.
            stored_hash = user.get("password")
            if stored_hash and bcrypt.checkpw(password.encode('utf-8'), stored_hash):
                st.success("**Credential Matched**: Redirecting...")
                st.session_state.current_user = username
                st.session_state.current_page = 'medical'
                st.rerun()
            else:
                st.warning("Password is incorrect")

def register():
    """Render the registration form and create a new user account.

    Inputs are normalised: email and username are lower-cased, the name is
    title-cased, and all are stripped. They are then validated in order, and
    the first failing check shows a warning. On success the password is
    bcrypt-hashed, the user document is inserted into ``collection``, and the
    session switches to the 'medical' page for the new user.
    """
    st.set_page_config(layout='centered', page_title="Brain MRI", page_icon="mri_of_brain.jpg")
    _, login_col = st.columns([5, 1])
    with login_col:
        if st.button("Login", use_container_width=True):
            st.session_state.current_page = 'login'
            st.rerun()

    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'

    # The inputs have no `key`, so writing st.session_state cannot reset them.
    # clear_on_submit=True already empties them after every submission.
    with st.form(key='register', clear_on_submit=True):
        st.subheader("Register")

        email = st.text_input("Email", placeholder="enter@your.email", max_chars=64)
        username = st.text_input("Username", placeholder="Enter Username", max_chars=17)
        name = st.text_input("Name", placeholder="Enter Your Full Name")
        password1 = st.text_input("Enter Password", type="password", max_chars=17)
        password2 = st.text_input("Confirm Password", type="password", max_chars=17)

        email = email.lower().strip()
        username = username.lower().strip()
        name = name.title().strip()

        submit = st.form_submit_button("Register")
        if not submit:
            return

        with st.spinner('Checking credentials...'):
            # Checked first: every later check also rejects empty fields, which
            # would otherwise hide this clearer message.
            if not (email and name and username and password1 and password2):
                st.warning(":red[Please complete all the fields above.]", icon="⚠️")
            elif re.match(pattern, email) is None:
                st.warning(":red[Please provide a valid email.]", icon="⚠️")
            elif collection.find_one({"email": email}):
                st.warning("Email already exists, please Login!")
            elif not name.replace(" ", "").isalpha():
                st.warning("Don't use numbers or special characters for name")
            elif len(username) < 5:
                st.warning("Username must be at least 5 characters long")
            elif len(password1) < 6:
                st.warning("Password must be at least 6 characters long")
            elif collection.find_one({"username": username}):
                st.warning("Username already exists, please try a different one.")
            elif password1 != password2:
                st.warning(":red[Passwords do not match. Please try again.]", icon="⚠️")
            else:
                # Hash exactly once: bcrypt with 13 rounds is deliberately slow.
                salt = bcrypt.gensalt(rounds=13)
                hashed_password = bcrypt.hashpw(password1.encode('utf-8'), salt)
                # The bcrypt hash already embeds its salt. "salt" is stored only
                # to keep the existing document shape.
                # NOTE(review): the find_one checks above and this insert are not
                # atomic. A unique index on email/username would stop concurrent
                # registrations from creating duplicates.
                document = {"name": name, "username": username, "email": email, "password": hashed_password, "salt": salt}
                collection.insert_one(document)
                st.success("**Successfully Registered**: Redirecting...")
                st.session_state.current_user = username
                st.session_state.current_page = 'medical'
                st.rerun()



def _scan_header(title):
    """Render the header shared by every scan page.

    Shows a "Welcome! <user>" button (clicking it just reruns the script), a
    Logout button wired to ``clear_cache``, a divider, and ``title`` as a
    subheader.
    """
    welcome_col, _, logout_col = st.columns([6, 6, 1])
    with welcome_col:
        if st.button(f"Welcome! {st.session_state.current_user}"):
            st.rerun()
    with logout_col:
        st.button("Logout", use_container_width=True, on_click=clear_cache)
    st.markdown("***")
    st.subheader(title)


def _classify_upload(uploader_label, model, dsize, class_names):
    """Offer a PNG/JPG upload and show the model's predicted class for it.

    Args:
        uploader_label: Label of the file-uploader widget.
        model: Keras model. The argmax of its ``predict`` output indexes
            ``class_names``.
        dsize: ``(width, height)`` passed to ``cv2.resize``.
        class_names: Labels in the order of the model's output units.
    """
    uploaded_file = st.file_uploader(uploader_label, type=['png', 'jpg'])
    if uploaded_file is None:
        return
    # getvalue() returns the whole buffer regardless of the stream position.
    file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode returns None for undecodable data, and cv2.resize would
        # raise on it.
        st.error("Could not read the uploaded file as an image.")
        return
    # NOTE(review): the image stays in OpenCV's BGR channel order with raw
    # 0-255 values. Confirm this matches how the models were trained.
    batch = cv2.resize(image, dsize)[np.newaxis, ...]
    st.text(class_names[int(np.argmax(model.predict(batch)))])


def medical_page():
    """Render the scan page: sidebar navigation plus the selected classifier.

    Every option shows the shared header. "Alzheimer" and "Brain Tumor"
    accept an upload and print the predicted class. "Abscesses" is a disabled
    placeholder because no abscess model is loaded.
    """
    st.set_page_config(layout='wide', page_title="Brain MRI", page_icon="mri_of_brain.jpg")
    st.sidebar.image("mri_of_brain.jpg")
    st.sidebar.title("Navigation")
    page_options = ["Alzheimer", "Brain Tumor", "Abscesses"]
    selected_page = st.sidebar.selectbox("Select a Scan", page_options)

    if selected_page == "Alzheimer":
        _scan_header("Here's your Alzheimer's Scan")
        _classify_upload(
            "Upload MRI scan image for detecting alzheimer's",
            alzheimer_model,
            (208, 176),
            ['ModerateDemented', 'NonDemented', 'VeryMildDemented', 'MildDemented'],
        )
    elif selected_page == "Brain Tumor":
        _scan_header("Here's your Brain Tumor Scan")
        _classify_upload(
            "Upload MRI scan image for detecting Brain Tumor",
            brain_tumor_model,
            (168, 150),
            ['Pituitary', 'No-Tumor', 'Meningioma', 'Glioma'],
        )
    elif selected_page == "Abscesses":
        _scan_header("Here's your Abscesses Scan")
        # Disabled uploader: there is no abscess model to run a prediction with.
        st.file_uploader("Upload MRI scan image for detecting Abscesses", type=['png', 'jpg'], disabled=True)
        st.write("Feature currently unavailable")


def main():
    """Render the page named by ``st.session_state.current_page``.

    Unknown page names render nothing.
    """
    renderers = {
        'login': login,
        'register': register,
        'medical': medical_page,
    }
    render_page = renderers.get(st.session_state.current_page)
    if render_page is not None:
        render_page()

if __name__ == "__main__":
    # Script entry point: render the page for the current session.
    main()