"""
app.py - the main application file for the gradio app
"""
import gc
import logging
import random
import re
import time
from pathlib import Path

import gradio as gr
import nltk
import torch
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_example_filenames, truncate_word_count

_here = Path(__file__).parent

nltk.download("stopwords", quiet=True)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)

MODEL_OPTIONS = [
    "pszemraj/led-large-book-summary",
    "pszemraj/led-base-book-summary",
]

# example display name -> file path. Rebound in ``__main__`` from the examples
# directory; the empty default keeps load_single_example_text from raising
# NameError if this module is imported rather than run.
name_to_path: dict = {}


def predict(
    input_text: str,
    model_name: str,
    token_batch_length: int = 2048,
    empty_cache: bool = True,
    **settings,
) -> list:
    """
    predict - load a model by name and summarize text with it in token batches

    The model and tokenizer are loaded on every call and released afterwards,
    so at most one model is held by this function at a time.

    :param str input_text: the input text to summarize
    :param str model_name: model name to use (passed to load_model_and_tokenizer)
    :param int token_batch_length: the length of the token batches to use
    :param bool empty_cache: whether to empty the CUDA cache before loading a new model
    :param settings: generation kwargs forwarded to summarize_via_tokenbatches
    :return: list of per-batch dicts; proc_submission reads the keys
        "summary" (indexed with [0]) and "summary_score"
    """
    if torch.cuda.is_available() and empty_cache:
        torch.cuda.empty_cache()

    model, tokenizer = load_model_and_tokenizer(model_name)
    try:
        summaries = summarize_via_tokenbatches(
            input_text,
            model,
            tokenizer,
            batch_length=token_batch_length,
            **settings,
        )
    finally:
        # release the model even when summarization raises, so a failed
        # request does not keep it referenced via the traceback frame
        del model
        del tokenizer
        gc.collect()

    return summaries


def proc_submission(
    input_text: str,
    model_name: str,
    num_beams: int,
    token_batch_length: int,
    length_penalty: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    max_input_length: int = 2560,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_name (str): the name of the model to use (one of MODEL_OPTIONS)
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no-repeat ngram size to use
        max_input_length (int, optional): the maximum input length in words.
            Defaults to 2560; overridden to 4096 for "base" models.

    Returns:
        tuple of (str in HTML format, string of the summary, str of scores)
    """
    logger.info("Processing submission")
    input_text = input_text or ""  # guard against a None value from the UI

    # reject trivially short input before doing any cleaning/model work
    if len(input_text) < 50:
        msg = (
            "\n\nError\n\n"
            f"Input text is too short to summarize. Detected {len(input_text)} characters. "
            "Please load text by selecting an example from the dropdown menu or by pasting text into the text box.\n\n"
        )
        logger.warning(msg)
        logger.warning("RETURNING EMPTY STRING")
        # all three outputs are text components, so return strings
        return msg, "", ""

    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        # cap each batch's summary at a quarter of the batch length
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    if "base" in model_name:
        max_input_length = 4096
        logger.info("Updating max_input_length to %s for base model", max_input_length)
    logger.info("max_input_length: %s", max_input_length)

    st = time.perf_counter()
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)

    msg = None
    if processed["was_truncated"]:
        truncated_input = processed["truncated_text"]
        input_wc = re.split(r"\s+", input_text)
        msg = (
            "\n\nWarning\n\n"
            f"Input text was truncated to {max_input_length} words. "
            f"That's about {100*max_input_length/len(input_wc):.2f}% of the submission.\n\n"
        )
        logger.warning(msg)
    else:
        # summarize the cleaned text here too, consistent with the truncated branch
        truncated_input = clean_text

    _summaries = predict(
        input_text=truncated_input,
        model_name=model_name,
        token_batch_length=token_batch_length,
        **settings,
    )
    sum_text = [
        f"\nBatch {i}:\n\t" + s["summary"][0] for i, s in enumerate(_summaries, start=1)
    ]
    sum_scores = [
        f"\n- Batch {i}:\n\t{round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries, start=1)
    ]
    sum_text_out = "\n".join(sum_text)
    scores_out = "\n".join(sum_scores)

    rt = round((time.perf_counter() - st) / 60, 2)
    logger.info("Runtime: %s minutes", rt)

    html = f"\n\nRuntime: {rt} minutes on CPU\n\n"
    if msg is not None:
        html += msg

    return html, sum_text_out, scores_out


def load_single_example_text(
    example_path: str,
):
    """
    load_single_example_text - load and clean the text of one bundled example

    Args:
        example_path (str): display name of the example as shown in the
            examples dropdown; used as a key into the module-level name_to_path

    Returns:
        str, the cleaned text of the example file
    """
    full_ex_path = Path(name_to_path[example_path])
    with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
        raw_text = f.read()
    return clean(raw_text, lower=False)


def load_uploaded_file(file_obj):
    """
    load_uploaded_file - read and clean the text of an uploaded file

    Args:
        file_obj: the Gradio file value. Either an object exposing ``.name``
            (the file's path), a plain path string/Path, a list whose first
            element is one of those, or None/empty when nothing was uploaded.

    Returns:
        str, the cleaned file contents, or an "Error: ..." message string
    """
    if isinstance(file_obj, list):
        file_obj = file_obj[0] if file_obj else None
    if file_obj is None:
        return "Error: No file was uploaded. Please select a text file first."

    file_path = Path(file_obj if isinstance(file_obj, (str, Path)) else file_obj.name)
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            raw_text = f.read()
        return clean(raw_text, lower=False)
    except Exception:
        # UI boundary: report failure to the user instead of crashing the callback
        logger.exception("Trying to load file with path %s failed", file_path)
        return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."


if __name__ == "__main__":
    logger.info("Starting up app")

    name_to_path = load_example_filenames(_here / "examples")
    logger.info("Loaded %d examples", len(name_to_path))
    demo = gr.Blocks(
        title="Summarize Long-Form Text",
    )
    _examples = list(name_to_path.keys())
    with demo:
        gr.Markdown("# Long-Form Summarization: LED & BookSum")
        gr.Markdown(
            "LED models ([model card](https://huggingface.co/pszemraj/led-large-book-summary)) fine-tuned to summarize long-form text.\n"
            "A [space with other models can be found here](https://huggingface.co/spaces/pszemraj/document-summarization)"
        )
        with gr.Column():
            gr.Markdown("## Load Inputs & Select Parameters")
            gr.Markdown(
                "Enter or upload text below, and it will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). "
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Model Name",
                )
                num_beams = gr.Radio(
                    choices=[2, 3, 4],
                    label="Beam Search: # of Beams",
                    value=2,
                )
            gr.Markdown(
                "Load a .txt - example or your own (_You may find [this OCR space](https://huggingface.co/spaces/pszemraj/pdf-ocr) useful_)"
            )
            with gr.Row():
                example_name = gr.Dropdown(
                    _examples,
                    label="Examples",
                    # random.choice raises IndexError on an empty list
                    value=random.choice(_examples) if _examples else None,
                )
                uploaded_file = gr.File(
                    label="File Upload",
                    file_count="single",
                    type="file",
                )
            with gr.Row():
                input_text = gr.Textbox(
                    lines=4,
                    max_lines=12,
                    label="Text to Summarize",
                    placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
                )
                with gr.Column():
                    load_examples_button = gr.Button(
                        "Load Example",
                    )
                    load_file_button = gr.Button("Upload File")
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Summary")
            gr.Markdown(
                "Summary generation should take approximately 1-2 minutes for most settings."
            )
            summarize_button = gr.Button(
                "Summarize!",
                variant="primary",
            )

            output_text = gr.HTML("\n\nOutput will appear below:\n\n")
            gr.Markdown("### Summary Output")
            summary_text = gr.Textbox(
                label="Summary", placeholder="The generated summary will appear here"
            )
            gr.Markdown(
                "The summary scores can be thought of as representing the quality of the summary. Less-negative numbers (closer to 0) are better:"
            )
            summary_scores = gr.Textbox(
                label="Summary Scores", placeholder="Summary scores will appear here"
            )

        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("### Advanced Settings")
            with gr.Row():
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=[512, 768, 1024, 1536],
                    label="token batch length",
                    value=1024,
                )

            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=3.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4],
                    label="no repeat ngram size",
                    value=3,
                )
        with gr.Column():
            gr.Markdown("### About the Model")
            gr.Markdown(
                "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
            )
            gr.Markdown(
                "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a Colab notebook for a tutorial."
            )
            gr.Markdown(
                "- **Update May 1, 2023:** Enabled faster inference times via `use_cache=True`, the number of words the model will process has been increased! Not on this demo, but there is a [test model](https://huggingface.co/pszemraj/led-large-book-summary-continued) available: an extension of `led-large-book-summary`."
            )
            gr.Markdown("---")

        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )

        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )

        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_name,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            outputs=[output_text, summary_text, summary_scores],
        )

    demo.launch(
        enable_queue=True,
    )