import gradio as gr
import pandas as pd
from PIL import Image
from rdkit import RDLogger
from molecule_generation_helpers import *
from property_prediction_helpers import *
RDLogger.logger().setLevel(RDLogger.ERROR)
# Predefined datasets (adjust the paths below to match your local files).
# Each value is a comma-separated spec, apparently:
#   "train CSV path, test CSV path, input column, output column"
# load_predefined_dataset reads only the first field (the train CSV path).
# NOTE(review): the " " entry looks like a blank placeholder choice for the
# dataset dropdown — confirm against the UI wiring.
predefined_datasets = {
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}
# Model names enabled in the app.
models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
# Available fusion types.
fusion_available = ["Concat"]
# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
    """Load a preview of a predefined dataset's training CSV.

    Args:
        dataset_name: Key into ``predefined_datasets`` (e.g. ``"BACE"``).

    Returns:
        A 4-tuple ``(preview_df, choices_update_a, choices_update_b, label)``.
        On success this is the first rows of the training CSV, two
        ``gr.update`` objects whose ``choices`` are the CSV's column names
        (presumably for the input/output column dropdowns), and the
        lower-cased dataset name. When the name is unknown, or maps to a
        blank placeholder entry, it is an empty DataFrame, two empty-choice
        updates and the message ``"Dataset not found"``.

    Raises:
        Whatever ``pd.read_csv`` raises if the configured path is missing or
        unreadable.
    """
    spec = predefined_datasets.get(dataset_name)
    # The blank " " placeholder entry is truthy; without the strip() check it
    # would reach pd.read_csv(" ") and raise instead of reporting not-found.
    if not spec or not spec.strip():
        return (
            pd.DataFrame(),
            gr.update(choices=[]),
            gr.update(choices=[]),
            "Dataset not found",
        )
    # The first comma-separated field of the spec is the training CSV path.
    train_path = spec.split(",")[0].strip()
    df = pd.read_csv(train_path)
    return (
        df.head(),
        gr.update(choices=list(df.columns)),
        gr.update(choices=list(df.columns)),
        dataset_name.lower(),
    )
# Function to handle dataset selection (predefined or custom)
def handle_dataset_selection(selected_dataset):
    """Compute visibility updates for the dataset-related components.

    Returns a tuple of eight ``gr.update(visible=...)`` objects. The first is
    always visible and the sixth is always hidden. The remaining six are
    shown only when the user picks ``"Custom Dataset"`` (presumably the
    train/test upload widgets and related fields).
    """
    is_custom = selected_dataset == "Custom Dataset"
    visibility_flags = (
        True,
        is_custom,
        is_custom,
        is_custom,
        is_custom,
        False,
        is_custom,
        is_custom,
    )
    return tuple(gr.update(visible=flag) for flag in visibility_flags)
# Dynamically show relevant hyperparameters based on selected model
def update_hyperparameters(model_name):
    """Show only the hyperparameter controls relevant to the chosen model.

    Args:
        model_name: Name of the downstream model selected in the UI.

    Returns:
        A tuple of five ``gr.update(visible=...)`` objects, one per
        hyperparameter control, in the order the callback's outputs are wired
        (that order is defined outside this function — confirm at the call
        site). Model names without an entry in the table hide all five
        controls rather than returning ``None``, which would not match the
        five expected outputs.
    """
    all_hidden = (False, False, False, False, False)
    visibility_by_model = {
        "XGBClassifier": (True, True, True, False, False),
        "SVR": (False, False, False, True, True),
        "Kernel Ridge": (False, False, True, True, True),
        "Linear Regression": all_hidden,
        "Default - Auto": all_hidden,
    }
    flags = visibility_by_model.get(model_name, all_hidden)
    return tuple(gr.update(visible=flag) for flag in flags)
# Build the comma-separated dataset spec from the selected columns, or return a prompt if something is missing
def select_columns(input_column, output_column, train_data, test_data, dataset_name):
    """Build the dataset spec string from the user's selections.

    Args:
        input_column: Name of the selected input column (falsy if unset).
        output_column: Name of the selected output column (falsy if unset).
        train_data: Uploaded training file. Either an object exposing its path
            as ``.name`` or a plain path string; ``None`` if not uploaded.
        test_data: Uploaded test file, same forms as ``train_data``.
        dataset_name: Name to append to the spec.

    Returns:
        ``"train_path,test_path,input_column,output_column,dataset_name"``
        when everything is provided. Otherwise a user-facing prompt: missing
        columns are reported first, then missing files.
    """
    if not (input_column and output_column):
        return "Please select both input and output columns."
    # Previously a missing upload raised AttributeError on ``None.name``.
    if train_data is None or test_data is None:
        return "Please upload both train and test datasets."
    # Accept both file objects exposing ``.name`` and plain path strings.
    train_path = getattr(train_data, "name", train_data)
    test_path = getattr(test_data, "name", test_data)
    return f"{train_path},{test_path},{input_column},{output_column},{dataset_name}"
# Function to set Dataset Name
def set_dataname(dataset_name, dataset_selector):
    """Return the effective dataset name.

    For ``"Custom Dataset"`` this is the user-supplied ``dataset_name``.
    For any other selector value, the selector value itself is the name.
    """
    if dataset_selector != "Custom Dataset":
        return dataset_selector
    return dataset_name
# Function to display the head of the uploaded CSV file
def display_csv_head(file):
    """Preview an uploaded CSV and offer its columns as dropdown choices.

    Args:
        file: The uploaded file. Either an object exposing its path as
            ``.name`` or a plain filepath string; ``None`` when nothing has
            been uploaded.

    Returns:
        ``(preview_df, choices_update_a, choices_update_b)``: the first rows
        of the CSV and two ``gr.update`` objects whose ``choices`` are its
        column names. With no file, an empty DataFrame and two empty-choice
        updates are returned.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
    # Depending on the gr.File configuration/Gradio version, the callback may
    # receive a tempfile-like object (path in ``.name``) or a path string.
    path = getattr(file, "name", file)
    df = pd.read_csv(path)
    return (
        df.head(),
        gr.update(choices=list(df.columns)),
        gr.update(choices=list(df.columns)),
    )
# Example molecules: each key maps to a SMILES string ("smiles") and the path
# of a preview image ("image"), relative to the working directory. Used by
# load_image and handle_image_selection. Replace the image paths with your own
# as needed.
smiles_image_mapping = {
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}
# Load the preview image for a selected example molecule (None if unavailable)
def load_image(path):
    """Open the preview image registered for an example molecule.

    Args:
        path: Key into ``smiles_image_mapping`` (e.g. ``"Mol 1"``). Despite
            its name, this is a mapping key, not a filesystem path.

    Returns:
        The image opened with ``PIL.Image.open``, or ``None`` if the key is
        unknown or the image file is missing or unreadable.
    """
    try:
        return Image.open(smiles_image_mapping[path]["image"])
    except (KeyError, OSError):
        # Best effort: a missing example image should not break the UI.
        # OSError covers FileNotFoundError and PIL's UnidentifiedImageError.
        # Unlike a bare ``except``, this lets KeyboardInterrupt, SystemExit
        # and genuine programming errors propagate.
        return None
# Function to handle image selection
def handle_image_selection(image_key):
    """Resolve an example-molecule key to its SMILES and a depiction.

    Returns ``(None, None)`` for an empty or ``None`` selection. Otherwise
    returns ``(smiles, image)``, where ``image`` is whatever
    ``smiles_to_image`` produces for that SMILES. Keys missing from
    ``smiles_image_mapping`` raise ``KeyError``.
    """
    if image_key:
        entry = smiles_image_mapping[image_key]
        selected_smiles = entry["smiles"]
        return selected_smiles, smiles_to_image(selected_smiles)
    return None, None
# Introduction
# Read the introduction text up front so the file handle is closed before the
# UI is built.
with open("INTRODUCTION.md", encoding="utf-8") as f:
    _introduction_text = f.read()

with gr.Blocks() as introduction:
    gr.Markdown(_introduction_text)
    # Debug section: alternating HTML / Markdown components, apparently for
    # comparing how each renderer displays the same sample text.
    # TODO(review): the sample markup after each label appears to be missing.
    gr.Markdown("---\n# Debug")
    gr.HTML("HTML text:  ")
    gr.Markdown("Markdown text: ")
    gr.HTML("HTML text:\n")
    gr.Markdown("Markdown text: ")
    gr.HTML("HTML text: ")
    gr.Markdown("Markdown text: ")
# Property Prediction
with gr.Blocks() as property_prediction:
    log_df = pd.DataFrame(
        {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []}
    )
    state = gr.State({"log_df": log_df})
    gr.HTML(
        '''
    
        Task : Property Prediction
        
        Models are finetuned with different combination of modalities on the uploaded or selected built data set.
    
        Task : Molecule Generation
        
        Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.