import os
import random
import uuid
from datetime import datetime
from difflib import ndiff

import gradio as gr

from data_loader import load_data

# Hugging Face credentials and the dataset repo that receives saved feedback, read from
# the environment (each is None when its variable is unset).
HF_TOKEN = os.getenv('HF_TOKEN')
HF_DATASET = os.getenv('HF_DATASET')

# Records to annotate, loaded once when this module runs.
data = load_data()
n_samples = len(data)

# Persists submitted rows to a private HF dataset; its columns are declared by
# saver.setup inside the Blocks definition.
saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, HF_DATASET, private=True)


def convert_diff_to_unified(diff):
    """Render a commit's per-file changes as a single unified-diff text.

    Args:
        diff: Iterable of dicts with ``old_path``, ``new_path`` and ``diff`` keys
            (``diff`` presumably holds that file's hunk text).

    Returns:
        One ``--- old`` / ``+++ new`` header plus the file's ``diff`` text per entry,
        with consecutive file sections joined by a single newline.
    """
    sections = []
    for changed in diff:
        old_path, new_path, hunks = changed["old_path"], changed["new_path"], changed["diff"]
        sections.append(f'--- {old_path}\n+++ {new_path}\n{hunks}')
    return "\n".join(sections)


def get_diff2html_view(raw_diff):
    """Wrap a unified diff in the HTML scaffold used to display it.

    The diff text is placed, HTML-escaped, in a hidden ``#diff-raw`` element next to an
    empty ``#diff-view`` target, inside a fixed-height (1400px) scrollable container.
    NOTE(review): the rendering presumably happens in JavaScript loaded via head.html
    (not visible here). Escaping is correct as long as that script reads ``#diff-raw``
    through ``textContent``/``innerText``, which decode the entities back. Confirm.

    Args:
        raw_diff: Unified diff text (as produced by ``convert_diff_to_unified``).

    Returns:
        The HTML snippet as a string.
    """
    from html import escape

    # Diffs are arbitrary repository content. Unescaped '<' / '&' (e.g. ``vector<int>``,
    # HTML sources) would be parsed as markup. That corrupts the raw diff text and lets
    # diff content inject elements into the page.
    escaped_diff = escape(raw_diff, quote=False)
    markup = f"""
    <div style='width:100%; height:1400px; overflow:auto; position: relative'>
        <div id='diff-raw' hidden>{escaped_diff}</div> 
        <div class="d2h-view-wrapper">
            <div id='diff-view'></div>
        </div>
    </div>
    """

    return markup


def get_github_link_md(repo, hash):
    """Return a Markdown link to commit ``hash`` in GitHub repository ``repo``.

    ``repo`` is presumably in ``owner/name`` form, as it is placed directly in the URL path.
    """
    commit_url = "https://github.com/" + f"{repo}/commit/{hash}"
    return f'[See the commit on Github]({commit_url})'


def char_diff_obj(change_type, pos, character, timestamp):
    """Package one character-level edit as a dict (one entry of the edit history).

    Keys, in order: ``type`` (the ndiff tag character), ``pos``, ``char`` and ``timestamp``.
    """
    return dict(type=change_type, pos=pos, char=character, timestamp=timestamp)


def update_commit_view(sample_ind):
    """Build the values displayed for one dataset record.

    Args:
        sample_ind: Index into ``data`` (an entry of the session's shuffled order).

    Returns:
        A 9-tuple in the order of the ``commit_view`` components: GitHub link markdown,
        diff HTML, repo, commit hash, load timestamp (``datetime.now().isoformat()``,
        naive local time), initial commit message, current commit message, previous
        commit message, and an empty edit history. All three message slots start as
        the record's ``prediction``.

    Raises:
        IndexError: if ``sample_ind`` is not a valid index into ``data``.
    """
    # Callers concatenate this result onto other tuples and hand it to Gradio as nine
    # outputs, so returning None would only fail later with an obscure TypeError.
    if not 0 <= sample_ind < n_samples:
        raise IndexError(f"sample index {sample_ind} out of range for {n_samples} samples")

    record = data[sample_ind]

    diff_view = get_diff2html_view(convert_diff_to_unified(record['mods']))

    repo_val = record['repo']
    hash_val = record['hash']
    github_link_md = get_github_link_md(repo_val, hash_val)

    diff_loaded_timestamp = datetime.now().isoformat()

    # The editable message starts from the record's prediction. The "start" copy is kept
    # for reference, and "prev" seeds the character-level diffing of subsequent edits.
    commit_message = record['prediction']
    commit_message_start = commit_message
    commit_message_prev = commit_message
    commit_message_history = []

    return (
        github_link_md, diff_view, repo_val, hash_val, diff_loaded_timestamp,
        commit_message_start, commit_message, commit_message_prev, commit_message_history)


def next_sample(current_sample_ind, shuffled_idx):
    """Advance to the next sample in the session's shuffled order.

    Args:
        current_sample_ind: Number of samples labeled/skipped so far (the slider value),
            which is also the position in ``shuffled_idx`` of the sample now shown.
        shuffled_idx: The session's permutation of ``range(n_samples)``.

    Returns:
        ``(new_slider_value, *commit_view_values)``. Once the last sample has been
        handled, the slider is set to ``n_samples`` and every commit-view component is
        left unchanged.
    """
    next_ind = int(current_sample_ind) + 1
    if next_ind >= n_samples:
        # Nothing left to show: shuffled_idx[next_ind] would be out of range. Gradio still
        # needs one value per output, so send a no-op update for each value that
        # update_commit_view would have returned.
        commit_view_size = 9
        return (n_samples,) + tuple(gr.update() for _ in range(commit_view_size))

    return (next_ind,) + update_commit_view(shuffled_idx[next_ind])


# Extra markup for the page <head>, passed to gr.Blocks below. Decoded explicitly as
# UTF-8 so the result does not depend on the host's locale.
with open("head.html", encoding="utf-8") as head_file:
    head_html = head_file.read()

# UI definition. Components are created in display order. The hidden ones carry per-sample
# and per-session state between the event handlers wired up further down.
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css") as application:
    # Identifiers of the commit currently shown. Hidden, but saved with each submission.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    # The session's shuffled order of dataset indices, filled on page load.
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Explicit encoding so the guide decodes identically regardless of host locale.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    with gr.Row():
        # Read-only progress bar: samples labeled or skipped so far in this session.
        current_sample_sld = gr.Slider(minimum=0, maximum=n_samples, step=1,
                                       value=0,
                                       interactive=False,
                                       label='sample_ind',
                                       info=f"Samples labeled/skipped (out of {n_samples})",
                                       show_label=False,
                                       container=False,
                                       scale=5)
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")
    with gr.Row():
        # Left: link to the commit on GitHub and the rendered diff.
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        # Right: the editable commit message plus hidden edit-tracking state.
        with gr.Column(scale=1):
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            # Message text as of the previous change event; each edit is diffed against it.
            commit_msg_prev = gr.TextArea(visible=False)
            # Character-level edit log (a list of char_diff_obj dicts).
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)

            submit_btn = gr.Button("Submit")

            session_val = gr.Textbox(info='Session', interactive=False, container=True, show_label=False,
                                     label='session')

            with gr.Row(visible=False):
                sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts', interactive=False,
                                                     container=True, show_label=False)
                # Callable value re-evaluated every second, so submissions and edits are
                # timestamped to within about one second.
                now_timestamp = gr.Textbox(info="Current time",
                                           interactive=False, container=True, show_label=False,
                                           value=lambda: datetime.now().isoformat(), every=1.0,
                                           label='submitted_ts')

    # Outputs filled from update_commit_view's return tuple, position for position;
    # the order here must match that tuple exactly.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history
    ]

    # Context recorded with every submission: session id, which commit was shown,
    # when it was loaded, and the clock value at submit time.
    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp
    ]

    # The annotator's result: the final message text and its character-level edit log.
    feedback_form = [
        # commit_msg_start,
        commit_msg,
        commit_msg_history
    ]

    # Declares the saved-row layout: slider value, then metadata, then form fields.
    # submit() passes its values to saver.flag in this same order. "feedback" is the
    # flagging-directory argument.
    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    # Skipping moves on to the next sample without saving anything.
    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)


    def submit(current_sample, shuffled_idx, *args):
        """Save the feedback for the current sample, then advance to the next one.

        ``args`` arrive in the order of ``feedback_metadata + feedback_form`` (the
        ``inputs`` of ``submit_btn.click`` below). The flagged row is therefore
        ``(current_sample, *feedback_metadata, *feedback_form)``, matching the component
        order given to ``saver.setup``.

        Returns:
            Whatever ``next_sample`` returns: the new slider value plus commit-view values.
        """
        # Valid sample positions are 0..n_samples-1. A slider value of n_samples (its
        # maximum) means nothing is left to label, so no row is written for it.
        if current_sample < n_samples:
            saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)


    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view
    )


    def on_commit_msg_changed(message, prev_message, history, timestamp):
        """Log the character-level edits between the previous and the current message.

        Args:
            message: Current text of the commit-message box.
            prev_message: Text as of the previous change event.
            history: Edit log accumulated so far (list of char_diff_obj dicts). May be
                None if the JSON component has no value yet, which is treated as empty.
            timestamp: Value of the once-per-second clock textbox, stamped on each edit.

        Returns:
            ``(message, updated_history)``. ``message`` becomes the new "previous" text.
        """
        # Build a fresh list rather than appending to the object passed in.
        updated_history = list(history) if history else []
        # Each ndiff entry is a two-character tag followed by one character of the
        # messages. Only insertions ('+') and deletions ('-') are logged. ``pos`` is the
        # entry's index within the ndiff output, which interleaves unchanged, removed and
        # added characters, so it is not an offset into either message.
        for i, entry in enumerate(ndiff(prev_message or "", message or "")):
            diff = char_diff_obj(entry[0], i, entry[-1], timestamp)
            if diff['type'] in ('+', '-'):
                updated_history.append(diff)
        return message, updated_history


    commit_msg.change(on_commit_msg_changed, inputs=[commit_msg, commit_msg_prev, commit_msg_history,
                                                     now_timestamp],
                      outputs=[commit_msg_prev, commit_msg_history])


    def init_session(current_sample):
        """Start a new annotator session when the page loads.

        Creates a random session id and a random presentation order of the dataset
        indices, then renders the sample at position ``current_sample`` of that order.

        Returns:
            ``(session_id, shuffled_order, *commit_view_values)``.
        """
        session_id = str(uuid.uuid4())
        order = list(range(n_samples))
        random.shuffle(order)
        first_view = update_commit_view(order[current_sample])
        return (session_id, order) + first_view


    application.load(
        init_session,
        inputs=[current_sample_sld],
        outputs=[session_val, shuffled_idx_val] + commit_view,
    )

# Start the Gradio server for the UI defined above. This is unguarded module-level code,
# so it also runs if the module is imported.
application.launch()