import logging
import os
import re
from pathlib import Path

import gradio as gr
import nltk
import torch
import transformers
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

_here = Path(__file__).parent

nltk.download("stopwords")  # TODO: find where this requirement originates from

transformers.logging.set_verbosity_error()
logging.basicConfig()


def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module
    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512
    Returns
    -------
    dict, the text and whether it was truncated
    """
    # split on whitespace with regex
    words = re.split(r"\s+", text)
    processed = {}
    if len(words) > max_words:
        processed["was_truncated"] = True
        processed["truncated_text"] = " ".join(words[:max_words])
    else:
        processed["was_truncated"] = False
        processed["truncated_text"] = text
    return processed


def proc_submission(
    input_text: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module
    Parameters
    ----------
    input_text : str, required, the text to be processed
    max_input_length : int, optional, the maximum length of the input text, default=512
    Returns
    -------
    str of HTML, the interactive HTML form for the model
    """

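    # generation settings passed through summarize_via_tokenbatches as
    # **settings; encoder_no_repeat_ngram_size=4 blocks 4-grams that appear
    # in the source text from being copied verbatim into the summary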
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "early_stopping": True,
        "min_length": 10,
        "do_sample": False,
    }

    history = {}
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)
    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        history["was_truncated"] = True
        msg = f"Input text was truncated to {max_input_length} characters."
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        tr_in = clean_text  # keep the cleaned text even when no truncation was needed
        history["was_truncated"] = False

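    # summarize the (possibly truncated) input in batches of about
    # token_batch_length tokens; each batch produces its own summary and score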
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model,
        tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0] for s in _summaries]
    sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in _summaries]

    history["Summary Text"] = "\n\t".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = tr_in
    html = ""
    for name, item in history.items():
        html += (
            f"<h2>{name}:</h2><hr><b>{item}</b><br><br>"
            if "summary" not in name.lower()
            else f"<h2>{name}:</h2><hr><b>{item}</b>"
        )

    html += ""

    return html


def load_examples(examples_dir="examples"):
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    examples = sorted(src.glob("*.txt"))  # stable order for the examples table
    # load the examples into a list
    text_examples = []
    for example in examples:
        with open(example, "r", encoding="utf-8") as f:
            text = f.read()
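            # each example row must match the Interface inputs order:
            # [text, num_beams, token_batch_length, length_penalty,
            #  repetition_penalty, no_repeat_ngram_size]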
            text_examples.append([text, 2, 1024, 0.7, 3.5, 3])

    return text_examples


if __name__ == "__main__":

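    # model and tokenizer become module-level names that proc_submission
    # reads as globals; "pszemraj/led-large-book-summary" is a Hugging Face
    # Hub model id, so the checkpoint is likely downloaded on first run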
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    title = "Long-form Summarization: LED & BookSum"
    description = (
        "This is a simple example of using the LED model to summarize long-form text. "
        "The model is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) "
        "on the BookSum dataset. The goal was to create a model that generalizes well "
        "and is useful for summarizing long text in academic and everyday use."
    )

    gr.Interface(
        proc_submission,
        inputs=[
            gr.inputs.Textbox(lines=10, label="input text"),
            gr.inputs.Slider(
                minimum=1, maximum=6, label="num_beams", default=4, step=1
            ),
            gr.inputs.Slider(
                minimum=512,
                maximum=2048,
                label="token_batch_length",
                default=1024,
                step=512,
            ),
            gr.inputs.Slider(
                minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
            ),
            gr.inputs.Slider(
                minimum=1.0,
                maximum=5.0,
                label="repetition_penalty",
                default=3.5,
                step=0.1,
            ),
            gr.inputs.Slider(
                minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
            ),
        ],
        outputs="html",
        examples_per_page=4,
        title=title,
        description=description,
        examples=load_examples(),
        cache_examples=False,
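        # enable_queue=True (below) serializes requests through Gradio's queue,
        # which keeps long summarization runs from hitting request timeouts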
    ).launch(enable_queue=True)