import json
import os
import pandas as pd
import requests
import threading
import streamlit as st
from datasets import load_dataset, load_metric

MODELS = ["CodeParrot", "InCoder", "CodeGen", "PolyCoder"]
GENERATION_MODELS = ["CodeParrot", "InCoder", "CodeGen"]


@st.cache()
def load_examples():
    """Return the example prompts stored in ``utils/examples.json``.

    The generation section reads each entry's ``name``, ``value`` and
    ``length`` keys. The function is wrapped in ``st.cache`` so repeated
    calls across reruns reuse the parsed result.

    Returns:
        The decoded JSON content (a list of example dicts).
    """
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open("utils/examples.json", "r", encoding="utf-8") as f:
        return json.load(f)
    
    
def load_evaluation(problem_index=2):
    """Load one HumanEval problem's unit test together with the code_eval metric.

    Args:
        problem_index: Row of the HumanEval ``test`` split to evaluate against.
            Defaults to 2, the task the quiz section tests against.

    Returns:
        A ``(code_eval, test_func)`` tuple. ``code_eval`` is the loaded metric.
        ``test_func`` is the problem's test code followed by a
        ``check(<entry_point>)`` call, ready to pass as a ``references`` entry.
    """
    # The code_eval metric only executes code when this variable is "1".
    # NOTE(review): the quiz feeds user-typed text from a text area into
    # code_eval, i.e. arbitrary code runs on the server.
    os.environ["HF_ALLOW_CODE_EVAL"] = "1"
    human_eval = load_dataset("openai_humaneval")
    problem = human_eval["test"][problem_index]
    entry_point = f"check({problem['entry_point']})"
    test_func = "\n" + problem["test"] + "\n" + entry_point
    code_eval = load_metric("code_eval")
    return code_eval, test_func


def read_markdown(path):
    """Render the Markdown file at ``path`` into the current Streamlit page.

    Raw HTML inside the file is passed through (``unsafe_allow_html=True``),
    so only trusted, repository-owned files should be rendered this way.

    Args:
        path: Path to a UTF-8 encoded Markdown file.
    """
    # Explicit encoding: the default is locale-dependent and breaks on emoji
    # and other non-ASCII content on some platforms.
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)


def generate_code(
    generations,
    model_name,
    gen_prompt,
    max_new_tokens,
    temperature,
    seed,
    timeout=120,
):
    """Ask a model's Hugging Face Space for a completion and record it.

    Sends the prompt and generation settings to the Space's predict API.
    On success, appends ``{model_name: generated_text}`` to ``generations``.

    Args:
        generations: List that receives the result. ``generate_code_threads``
            shares one list between several threads.
        model_name: Model display name; lower-cased to build the Space URL.
        gen_prompt: Prompt text to complete.
        max_new_tokens: Number of tokens to generate, forwarded to the Space.
        temperature: Sampling temperature, forwarded to the Space.
        seed: Random seed, forwarded to the Space.
        timeout: Seconds to wait for the Space before giving up, passed to
            ``requests.post``. ``None`` waits indefinitely.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx HTTP status. Nothing is appended in that case.
    """
    # call space using its API endpoint
    url = (
        f"https://hf.space/embed/codeparrot/{model_name.lower()}-subspace/+/api/predict/"
    )
    # Without a timeout a stalled Space blocks forever, and so does the caller's
    # Thread.join() and the spinner shown in the UI.
    response = requests.post(
        url=url,
        json={"data": [gen_prompt, max_new_tokens, temperature, seed]},
        timeout=timeout,
    )
    # Surface HTTP errors directly instead of an obscure KeyError on the body.
    response.raise_for_status()
    generated_text = response.json()["data"][0]
    generations.append({model_name: generated_text})


def generate_code_threads(
    generations, models, gen_prompt, max_new_tokens, temperature, seed
):
    """Query every model in ``models`` concurrently and wait for all of them.

    Each model gets its own thread running ``generate_code``. A successful
    call appends ``{model_name: text}`` to the shared ``generations`` list, so
    entries arrive in completion order, not in ``models`` order. An exception
    raised inside a worker thread is not propagated by ``join()``; that model
    simply contributes no entry.
    """
    workers = []
    for name in models:
        worker = threading.Thread(
            target=generate_code,
            args=(generations, name, gen_prompt, max_new_tokens, temperature, seed),
        )
        # Start each request as soon as it is created so they overlap.
        worker.start()
        workers.append(worker)

    # Block until every request has finished, successfully or not.
    for worker in workers:
        worker.join()

@st.cache(show_spinner=False)
def generate_teaser(gen_prompt):
    """Return a short CodeParrot completion of ``gen_prompt`` for the intro teaser.

    Uses fixed settings (8 new tokens, temperature 0.2, seed 42). The result is
    cached via ``st.cache`` so repeating the same prompt does not re-query.
    """
    model = "CodeParrot"
    outputs = []
    generate_code(
        outputs, model, gen_prompt, max_new_tokens=8, temperature=0.2, seed=42
    )
    first_result = outputs[0]
    return first_result[model]
    
st.set_page_config(page_icon=":laptop:", layout="wide")
# Sidebar table of contents.
with open("utils/table_contents.md", "r", encoding="utf-8") as f:
    contents = f.read()

st.sidebar.markdown(contents)

# Introduction
st.title("Code generation with 🤗")
read_markdown("utils/summary.md")
## teaser
# Teaser: a short CodeParrot completion of an editable one-line prompt.
example_text = "def print_hello_world():"
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    gen_prompt = st.text_area(
        "",
        value=example_text,
        height=100,
    ).strip()
    # Streamlit requires widget keys to be unique across the whole page
    # (regardless of widget type), so use a descriptive string key.
    if st.button("Generate code!", key="teaser_generate"):
        with st.spinner("Generating code..."):
            st.code(generate_teaser(gen_prompt))
read_markdown("utils/intro.md")

# Code datasets
st.subheader("1 - Code datasets")
read_markdown("datasets/intro.md")
read_markdown("datasets/github_code.md")
col1, col2 = st.columns([1, 2])
with col1:
    # Widget keys must be unique page-wide; use a descriptive string key.
    selected_model = st.selectbox("", MODELS, key="dataset_model")
# Show the dataset description of the chosen model.
read_markdown(f"datasets/{selected_model.lower()}.md")


# Model architecture
st.subheader("2 - Model architecture")
read_markdown("architectures/intro.md")
col1, col2 = st.columns([1, 2])
with col1:
    # Widget keys must be unique page-wide; use a descriptive string key.
    selected_model = st.selectbox("", MODELS, key="architecture_model")
# Show the architecture description of the chosen model.
read_markdown(f"architectures/{selected_model.lower()}.md")

# Model evaluation
st.subheader("3 - Code model evaluation")
read_markdown("evaluation/intro.md")
read_markdown("evaluation/demo_humaneval.md")
## quiz
st.markdown("Below you can try solving this problem or visualize the solution of CodeParrot:")
with open("evaluation/problem.md", "r", encoding="utf-8") as f:
    problem = f.read()
with open("evaluation/solution.md", "r", encoding="utf-8") as f:
    solution = f.read()

candidate_solution = st.text_area(
    "Complete the problem:",
    value=problem,
    height=240,
).strip()
# Widget keys must be unique page-wide, so the buttons use descriptive string keys.
if st.button("Test my solution", key="test_solution"):
    with st.spinner("Testing..."):
        # NOTE: code_eval executes the user's text against the HumanEval test.
        code_eval, test_func = load_evaluation()
        test_cases = [test_func]
        candidates = [[candidate_solution]]
        pass_at_k, _ = code_eval.compute(references=test_cases, predictions=candidates)
        # With a single candidate, pass@1 is either 0 or 1.
        if pass_at_k["pass@1"] < 1:
            text = "Your solution didn't pass the test, pass@1 is 0 😕"
        else:
            text = "Congrats your pass@1 is 1! 🎉"
        st.markdown(text)
if st.button("Show model solution", key="show_solution"):
    st.markdown(solution)
    
# Code generation
st.subheader("4 - Code generation ✨")
read_markdown("generation/intro.md")
col1, col2, col3 = st.columns([7, 1, 6])
with col1:
    st.markdown("**Models**")
    # Widget keys must be unique page-wide; use a descriptive string key.
    selected_models = st.multiselect(
        "Select code generation models to compare:",
        GENERATION_MODELS,
        default=GENERATION_MODELS,
        key="generation_models",
    )
    st.markdown(" ")
    st.markdown("**Examples**")
    examples = load_examples()
    example_names = [example["name"] for example in examples]
    name2id = {name: i for i, name in enumerate(example_names)}
    selected_example = st.selectbox(
        "Select one of the following examples or implement yours:", example_names
    )
    # Prefill the prompt and default token budget from the chosen example.
    example_entry = examples[name2id[selected_example]]
    example_text = example_entry["value"]
    default_length = example_entry["length"]
with col3:
    st.markdown("**Generation settings**")
    temperature = st.slider(
        "Temperature:", value=0.2, min_value=0.1, step=0.1, max_value=2.0
    )
    max_new_tokens = st.slider(
        "Number of tokens to generate:",
        value=default_length,
        min_value=8,
        step=4,
        max_value=256,
    )
    seed = st.slider("Random seed:", value=42, min_value=0, step=1, max_value=1000)
gen_prompt = st.text_area(
    "Generate code with prompt:",
    value=example_text,
    height=200,
).strip()
if st.button("Generate code!", key="generation_generate"):
    with st.spinner("Generating code..."):
        # use threading
        generations = []
        generate_code_threads(
            generations,
            selected_models,
            gen_prompt=gen_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            seed=seed,
        )
        # Threads append {model: text} dicts in completion order, and a model
        # whose request failed appends nothing. Index the results by model name
        # and render them in the order the user selected the models, so a
        # failure never shifts headers onto the wrong output.
        outputs = {
            name: text for result in generations for name, text in result.items()
        }
        for model_name in selected_models:
            if model_name in outputs:
                st.markdown(f"**{model_name}**")
                st.code(outputs[model_name])
        if len(generations) < len(selected_models):
            st.markdown("<span style='color:red'>Warning: Some models run into timeout, try another time or reduce the Number of tokens to generate. You can also try generating code using the original subspaces: [InCoder](https://huggingface.co/spaces/loubnabnl/incoder-subspace), [CodeGen](https://huggingface.co/spaces/loubnabnl/codegen-subspace), [CodeParrot](https://huggingface.co/spaces/loubnabnl/codeparrot-subspace)</span>", unsafe_allow_html=True)

# Resources
st.subheader("Resources")
read_markdown("utils/resources.md")