import logging
import re
from pathlib import Path

import gradio as gr
import nltk
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

# directory containing this file; example files are resolved relative to it
_here = Path(__file__).parent

# runs at import time, before anything else in the app is set up
nltk.download("stopwords")  # TODO: find where this requirement originates from

import transformers

# only surface error-level messages from the transformers library
transformers.logging.set_verbosity_error()
# default root-logger configuration so logging.warning() output is emitted
logging.basicConfig()


def truncate_word_count(text: str, max_words: int = 512) -> dict:
    """
    truncate_word_count - a helper function for the gradio module

    A "word" is a maximal run of non-whitespace characters, so leading,
    trailing or repeated whitespace never counts towards the limit.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512
    Returns
    -------
    dict with two keys:
        "was_truncated" : bool, True if the text had more than max_words words
        "truncated_text" : str, the first max_words words joined by single
            spaces when truncated, otherwise the original text unchanged
    """
    # match the words themselves instead of splitting on whitespace: splitting
    # produces empty strings for leading/trailing whitespace, which would be
    # miscounted as words
    words = re.findall(r"\S+", text)
    was_truncated = len(words) > max_words
    return {
        "was_truncated": was_truncated,
        "truncated_text": " ".join(words[:max_words]) if was_truncated else text,
    }


def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Cleans the submitted text, truncates it to a word limit, summarizes it in
    token batches and renders the outcome as HTML.

    Parameters
    ----------
    input_text : str, required, the text to be processed
    model_size : str, required, "base" selects the small model/tokenizer
        (and a word limit of 1024); any other value selects the large pair
    num_beams : forwarded to the summarizer as a generation setting
    token_batch_length : batch length passed to the summarizer; a quarter of
        it is also used as the "max_length" generation setting
    length_penalty : forwarded to the summarizer as a generation setting
    repetition_penalty : forwarded to the summarizer as a generation setting
    no_repeat_ngram_size : forwarded to the summarizer as a generation setting
    max_input_length : int, optional, the maximum number of words of input
        text when model_size is not "base", default=512
    Returns
    -------
    str of HTML with one section per entry: truncation warning/status,
    summary text, summary scores and the (cleaned, possibly truncated) input
    """
    # local import so the function brings its own dependency into scope
    from html import escape

    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }

    history = {}
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)
    # always summarize the cleaned text; when nothing was truncated,
    # "truncated_text" is the cleaned text itself
    tr_in = processed["truncated_text"]
    if processed["was_truncated"]:
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        history["was_truncated"] = False

    use_base = model_size == "base"
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    # user- and model-provided text is escaped before being embedded in HTML
    sum_text = [
        f"Section {i}: " + escape(s["summary"][0]) for i, s in enumerate(_summaries)
    ]
    sum_scores = [
        f"\n - Section {i}: {round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries)
    ]

    history["Summary Text"] = "<br>".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = escape(tr_in)

    # summary sections are rendered plain, every other section in bold
    return "".join(
        f"<h2>{name}:</h2><hr><b>{item}</b><br><br>"
        if "summary" not in name.lower()
        else f"<h2>{name}:</h2><hr>{item}<br><br>"
        for name, item in history.items()
    )


def load_examples(examples_dir="examples"):
    """
    load_examples - a helper function for the gradio module to load examples

    Reads every ``*.txt`` file in ``examples_dir`` (resolved relative to this
    file's directory; created if missing) as UTF-8, in sorted path order so
    the examples appear in a stable order across runs.

    Args:
        examples_dir: str, directory holding the example text files,
            default="examples"
    Returns:
        list of lists, one per example file: the file's text followed by the
        fixed settings "large", 2, 512, 0.7, 3.5, 3
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    # explicit encoding: the platform default may not be UTF-8
    return [
        [example.read_text(encoding="utf-8"), "large", 2, 512, 0.7, 3.5, 3]
        for example in sorted(src.glob("*.txt"))
    ]


if __name__ == "__main__":
    # Load both checkpoints up front; proc_submission reads these
    # module-level names when choosing the model for a request.
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")

    title = "Long-Form Summarization: LED & BookSum"
    description = "A simple demo of how to use a fine-tuned LED model to summarize long-form text. [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."

    # UI controls, built in the same order as proc_submission's parameters
    text_input = gr.inputs.Textbox(
        lines=10,
        label="input text",
        placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
    )
    model_size_input = gr.inputs.Radio(
        choices=["base", "large"], label="model size", default="base"
    )
    num_beams_input = gr.inputs.Slider(
        minimum=2, maximum=4, label="num_beams", default=2, step=1
    )
    batch_length_input = gr.inputs.Slider(
        minimum=512,
        maximum=1024,
        label="token_batch_length",
        default=512,
        step=256,
    )
    length_penalty_input = gr.inputs.Slider(
        minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
    )
    repetition_penalty_input = gr.inputs.Slider(
        minimum=1.0,
        maximum=5.0,
        label="repetition_penalty",
        default=3.5,
        step=0.1,
    )
    ngram_size_input = gr.inputs.Slider(
        minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
    )

    demo = gr.Interface(
        proc_submission,
        inputs=[
            text_input,
            model_size_input,
            num_beams_input,
            batch_length_input,
            length_penalty_input,
            repetition_penalty_input,
            ngram_size_input,
        ],
        outputs="html",
        examples_per_page=2,
        title=title,
        description=description,
        article="The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial.",
        examples=load_examples(),
        cache_examples=True,
    )
    demo.launch()