File size: 4,841 Bytes
fe0e9af
 
 
904400a
fe0e9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66e7228
fe0e9af
 
 
 
 
 
 
 
 
 
 
904400a
 
fe0e9af
 
 
 
 
 
 
 
 
66e7228
fe0e9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66e7228
 
 
 
 
 
 
fe0e9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66e7228
 
fe0e9af
 
 
 
 
 
 
 
281a517
fe0e9af
 
 
66e7228
fe0e9af
 
66e7228
fe0e9af
66e7228
 
 
fe0e9af
 
 
 
 
 
 
 
ca046ea
 
66e7228
fe0e9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66e7228
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import logging
from pathlib import Path
import os
import re
import gradio as gr
import nltk
import torch
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

# Directory containing this script; load_examples() resolves its examples folder against it.
_here = Path(__file__).parent

# Fetch the NLTK "stopwords" corpus at import time.
# TODO: find which dependency requires this corpus (possibly cleantext or summarize — unconfirmed).
nltk.download("stopwords")

import transformers

# Suppress transformers log output below ERROR level.
transformers.logging.set_verbosity_error()
# Attach a default handler to the root logger so logging.warning(...) calls in this module are emitted.
logging.basicConfig()


def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module
    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512
    Returns
    -------
    dict with keys:
        "was_truncated" : bool, True if `text` had more than `max_words` words
        "truncated_text" : str, the first `max_words` words joined by single
            spaces when truncated, otherwise `text` unchanged
    """
    # str.split() with no separator splits on runs of whitespace AND drops
    # leading/trailing whitespace; re.split(r"\s+", ...) would yield empty
    # strings at the edges and over-count words (e.g. " a b " -> 4 "words").
    words = text.split()
    if len(words) > max_words:
        return {"was_truncated": True, "truncated_text": " ".join(words[:max_words])}
    return {"was_truncated": False, "truncated_text": text}


def proc_submission(
    input_text: str,
    num_beams,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    token_batch_length,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module
    Parameters
    ----------
    input_text : str, required, the text to be processed
    num_beams : passed to summarize_via_tokenbatches as the "num_beams" setting
    length_penalty : passed to summarize_via_tokenbatches as the "length_penalty" setting
    repetition_penalty : passed to summarize_via_tokenbatches as the "repetition_penalty" setting
    no_repeat_ngram_size : passed to summarize_via_tokenbatches as the "no_repeat_ngram_size" setting
    token_batch_length : passed to summarize_via_tokenbatches as `batch_length`
    max_input_length : int, optional, the maximum number of WORDS kept from the
        cleaned input text, default=512
    Returns
    -------
    str of HTML, showing the (possibly truncated) input, any truncation warning,
    the summaries and their scores. All values are HTML-escaped.

    Notes
    -----
    Relies on the module-level globals `model` and `tokenizer`, which are only
    assigned in this script's `__main__` block.
    """
    # Local import keeps the file-level import list unchanged.
    from html import escape

    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
    }

    history = {}
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)
    # Summarize the cleaned text in both cases; previously the raw text was
    # used whenever no truncation happened, so cleaning only applied to long inputs.
    history["input_text"] = processed["truncated_text"]
    history["was_truncated"] = processed["was_truncated"]
    if processed["was_truncated"]:
        # truncate_word_count limits WORDS, not characters.
        msg = f"Input text was truncated to {max_input_length} words."
        logging.warning(msg)
        history["WARNING"] = msg

    _summaries = summarize_via_tokenbatches(
        history["input_text"],
        model,
        tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0] for s in _summaries]
    sum_scores = [f"\n - {round(s['summary_score'], 4)}" for s in _summaries]

    history["Input"] = input_text
    history["Summary Text"] = "\n\t".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)

    sections = []
    for name, item in history.items():
        # Escape everything: values include raw user input rendered as HTML.
        section = f"<h2>{escape(name)}:</h2><hr><b>{escape(str(item))}</b>"
        if "summary" not in name.lower():
            section += "<br><br>"
        sections.append(section)

    return "".join(sections)


def load_examples(examples_dir="examples"):
    """
    load_examples - build the example rows for the gradio interface
    Parameters
    ----------
    examples_dir : str, optional, directory (relative to this script) holding
        *.txt example files; it is created if missing, default="examples"
    Returns
    -------
    list of lists, one row per example file: the file's text followed by default
    values for the remaining inputs, in the same order as the parameters of
    proc_submission (num_beams, length_penalty, repetition_penalty,
    no_repeat_ngram_size, token_batch_length)
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    text_examples = []
    # sorted() gives a stable example order regardless of filesystem listing order
    for example in sorted(src.glob("*.txt")):
        # explicit encoding: the platform default is not necessarily UTF-8
        text = example.read_text(encoding="utf-8")
        text_examples.append([text, 4, 0.7, 3.5, 3, 2048])

    return text_examples


if __name__ == "__main__":

    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    title = "Long-form text summarization with LED on the BookSum dataset"
    description = (
        "This is a simple example of using the LED model to summarize a long-form text."
    )

    # Gradio passes input values to the function positionally, so this list
    # (and each row returned by load_examples) must follow the parameter order
    # of proc_submission.
    gr.Interface(
        proc_submission,
        inputs=[
            gr.inputs.Textbox(lines=10, label="input text"),
            gr.inputs.Slider(
                minimum=4, maximum=10, label="num_beams", default=4, step=1
            ),
            gr.inputs.Slider(
                minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
            ),
            gr.inputs.Slider(
                minimum=1.0,
                maximum=5.0,
                label="repetition_penalty",
                default=3.5,
                step=0.1,
            ),
            gr.inputs.Slider(
                minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
            ),
            gr.inputs.Slider(
                minimum=512,
                maximum=4096,
                label="token_batch_length",
                default=2048,
                step=512,
            ),
        ],
        outputs="html",
        examples_per_page=4,
        title=title,
        description=description,
        examples=load_examples(),
    ).launch(enable_queue=True, share=True)