import logging
import re
from pathlib import Path

import gradio as gr
import nltk
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

_here = Path(__file__).parent

nltk.download("stopwords")  # TODO: find where this requirement originates from

import transformers

transformers.logging.set_verbosity_error()
logging.basicConfig()


def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module
    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512
    Returns
    -------
    dict, the text and whether it was truncated
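
    Example
    -------
    A doctest-style sketch of the returned dict (splitting is on whitespace only):

    >>> truncate_word_count("alpha beta gamma", max_words=2)
    {'was_truncated': True, 'truncated_text': 'alpha beta'}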
    """
    # split on whitespace with regex
    words = re.split(r"\s+", text)
    processed = {}
    if len(words) > max_words:
        processed["was_truncated"] = True
        processed["truncated_text"] = " ".join(words[:max_words])
    else:
        processed["was_truncated"] = False
        processed["truncated_text"] = text
    return processed


def proc_submission(
    input_text: str,
    model_size: str,
    num_beams: int,
    token_batch_length: int,
    length_penalty: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module
    Parameters
    ----------
    input_text : str, required, the text to be processed
    max_input_length : int, optional, the maximum length of the input text, default=512
    Returns
    -------
    str of HTML, the interactive HTML form for the model
    """

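    # Decoding settings, passed through summarize_via_tokenbatches as **kwargs;
    # they match the standard transformers generate() beam-search parameters
    # (deterministic beam search with early stopping, no sampling).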
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "early_stopping": True,
        "min_length": 10,
        "do_sample": False,
    }

    history = {}
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)
    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        tr_in = clean_text  # use the cleaned text even when no truncation is needed
        history["was_truncated"] = False

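    # summarize_via_tokenbatches (from the local summarize module) chunks the
    # input into token batches; judging by its use below, each result dict
    # carries a "summary" list and a "summary_score".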
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if model_size == "base" else model,
        tokenizer_sm if model_size == "base" else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
    sum_scores = [
        f"\n - Section {i}: {round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries)
    ]

    history["Summary Text"] = "<br>".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = tr_in
    html = ""

    for name, item in history.items():
        html += (
            f"<h2>{name}:</h2><hr><b>{item}</b><br><br>"
            if "summary" not in name.lower()
            else f"<h2>{name}:</h2><hr><b>{item}</b>"
        )

    html += ""

    return html

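# Illustrative call (hypothetical `long_text` string); the positional arguments
# mirror the Gradio inputs defined in the Interface below:
#   html = proc_submission(long_text, "base", 1, 512, 0.7, 3.5, 3)
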

def load_examples(examples_dir="examples"):
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    examples = sorted(src.glob("*.txt"))  # sort for a stable example order
    # load the examples into a list
    text_examples = []
    for example in examples:
        with open(example, "r", encoding="utf-8") as f:
            text = f.read()
            # each row supplies the seven Interface inputs, in order
            text_examples.append([text, "large", 1, 512, 0.7, 3.5, 3])

    return text_examples

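# Entry point: load both LED checkpoints (used as module-level globals inside
# proc_submission) and build the Gradio Interface.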

if __name__ == "__main__":

    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")
    title = "Long-Form Summarization: LED & BookSum"
    description = "A simple demo of how to use a fine-tuned LED model to summarize long-form text. [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."

    gr.Interface(
        proc_submission,
        inputs=[
            gr.inputs.Textbox(
                lines=10,
                label="input text",
                placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
            ),
            gr.inputs.Radio(
                choices=["base", "large"], label="model size", default="base"
            ),
            gr.inputs.Slider(
                minimum=1, maximum=4, label="num_beams", default=1, step=1
            ),
            gr.inputs.Slider(
                minimum=512,
                maximum=1024,
                label="token_batch_length",
                default=512,
                step=256,
            ),
            gr.inputs.Slider(
                minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
            ),
            gr.inputs.Slider(
                minimum=1.0,
                maximum=5.0,
                label="repetition_penalty",
                default=3.5,
                step=0.1,
            ),
            gr.inputs.Slider(
                minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
            ),
        ],
        outputs="html",
        examples_per_page=2,
        title=title,
        description=description,
        article="The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial.",
        examples=load_examples(),
        cache_examples=True,
    ).launch()