import logging
import os
import re
from pathlib import Path

import gradio as gr
import nltk
import torch
import transformers
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

_here = Path(__file__).parent

nltk.download("stopwords")  # TODO: find where this requirement originates from

transformers.logging.set_verbosity_error()
logging.basicConfig()


def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words to keep, default=512

    Returns
    -------
    dict, the (possibly truncated) text and whether it was truncated
    """
    # split on whitespace with regex
    words = re.split(r"\s+", text)
    processed = {}
    if len(words) > max_words:
        processed["was_truncated"] = True
        processed["truncated_text"] = " ".join(words[:max_words])
    else:
        processed["was_truncated"] = False
        processed["truncated_text"] = text
    return processed


def proc_submission(
    input_text: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Parameters
    ----------
    input_text : str, required, the text to be summarized
    num_beams : int, the number of beams to use for beam search
    token_batch_length : int, the number of tokens per batch fed to the model
    length_penalty : float, the length penalty used during generation
    repetition_penalty : float, the repetition penalty used during generation
    no_repeat_ngram_size : int, size of n-grams that may not repeat in the output
    max_input_length : int, optional, the maximum length of the input text in words, default=512

    Returns
    -------
    str of HTML, the summaries, scores, and input rendered for display
    """
    # generation settings forwarded to summarize_via_tokenbatches
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "early_stopping": True,
        "min_length": 10,
        "do_sample": False,
    }
    history = {}
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)

    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        history["was_truncated"] = True
        msg = f"Input text was truncated to {max_input_length} words."
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        tr_in = clean_text
        history["was_truncated"] = False

    _summaries = summarize_via_tokenbatches(
        tr_in,
        model,
        tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0] for s in _summaries]
    sum_scores = [f"\n - {round(s['summary_score'], 4)}" for s in _summaries]

    history["Summary Text"] = "\n\t".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = tr_in

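    # render the history dict as a simple HTML report: summary entries are kept
    # preformatted so their line breaks survive, everything else is shown as a
    # bold label followed by its value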
    html = ""
    for name, item in history.items():
        html += (
            f"<p><b>{name}:</b></p><p>{item}</p>"
            if "summary" not in name.lower()
            else f"<p><b>{name}:</b></p><pre>{item}</pre>"
        )
    return html


def load_examples(examples_dir="examples"):
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    examples = [f for f in src.glob("*.txt")]
    # load the examples into a list of [text, num_beams, token_batch_length,
    # length_penalty, repetition_penalty, no_repeat_ngram_size] rows
    text_examples = []
    for example in examples:
        with open(example, "r") as f:
            text = f.read()
        text_examples.append([text, 2, 1024, 0.7, 3.5, 3])

    return text_examples


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")

    title = "Long-form Summarization: LED & BookSum"
    description = (
        "This is a simple example of using the LED model to summarize long-form text. "
        "This model is a fine-tuned version of "
        "[allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) "
        "on the BookSum dataset. The goal was to create a model that generalizes "
        "well and is useful for summarizing large amounts of text in academic and "
        "everyday use."
    )

    gr.Interface(
        proc_submission,
        inputs=[
            gr.inputs.Textbox(lines=10, label="input text"),
            gr.inputs.Slider(
                minimum=1, maximum=6, label="num_beams", default=4, step=1
            ),
            gr.inputs.Slider(
                minimum=512,
                maximum=2048,
                label="token_batch_length",
                default=1024,
                step=512,
            ),
            gr.inputs.Slider(
                minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
            ),
            gr.inputs.Slider(
                minimum=1.0,
                maximum=5.0,
                label="repetition_penalty",
                default=3.5,
                step=0.1,
            ),
            gr.inputs.Slider(
                minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
            ),
        ],
        outputs="html",
        examples_per_page=4,
        title=title,
        description=description,
        examples=load_examples(),
        cache_examples=False,
    ).launch(enable_queue=True)
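
    # NOTE: the same pipeline can be run without the web UI, e.g.:
    #   model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    #   html = proc_submission(long_text, 4, 1024, 0.7, 3.5, 3)
    # where `long_text` is a placeholder for the document to summarize and the
    # positional args mirror the slider defaults above (num_beams,
    # token_batch_length, length_penalty, repetition_penalty, no_repeat_ngram_size)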