Peter commited on
Commit
fe0e9af
1 Parent(s): 4a607b7

:sparkles: add app mwe

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ import os
4
+ import gradio as gr
5
+ import nltk
6
+ import torch
7
+ from cleantext import clean
8
+
9
+ from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
10
+
11
+ _here = Path(__file__).parent
12
+
13
+ nltk.download("stopwords") # TODO=find where this requirement originates from
14
+
15
+ import transformers
16
+
17
+ transformers.logging.set_verbosity_error()
18
+ logging.basicConfig()
19
+
20
+ def truncate_word_count(text, max_words=512):
21
+ """
22
+ truncate_word_count - a helper function for the gradio module
23
+ Parameters
24
+ ----------
25
+ text : str, required, the text to be processed
26
+ max_words : int, optional, the maximum number of words, default=512
27
+ Returns
28
+ -------
29
+ dict, the text and whether it was truncated
30
+ """
31
+ words = text.split()
32
+ processed = {}
33
+ if len(words) > max_words:
34
+ processed["was_truncated"] = True
35
+ processed["truncated_text"] = " ".join(words[:max_words])
36
+ else:
37
+ processed["was_truncated"] = False
38
+ processed["truncated_text"] = text
39
+ return processed
40
+
41
+ def proc_submission(
42
+ input_text: str,
43
+ num_beams,
44
+ length_penalty,
45
+ repetition_penalty,
46
+ no_repeat_ngram_size,
47
+ token_batch_length,
48
+ max_input_length: int = 512,
49
+ ):
50
+ """
51
+ proc_submission - a helper function for the gradio module
52
+ Parameters
53
+ ----------
54
+ input_text : str, required, the text to be processed
55
+ max_input_length : int, optional, the maximum length of the input text, default=512
56
+ Returns
57
+ -------
58
+ str of HTML, the interactive HTML form for the model
59
+ """
60
+
61
+ settings = {
62
+ "length_penalty": length_penalty,
63
+ "repetition_penalty": repetition_penalty,
64
+ "no_repeat_ngram_size": no_repeat_ngram_size,
65
+ "encoder_no_repeat_ngram_size": 4,
66
+ "num_beams": num_beams,
67
+ }
68
+
69
+ history = {}
70
+ clean_text = clean(input_text, lower=False)
71
+ processed = truncate_word_count(clean_text, max_input_length)
72
+ if processed["was_truncated"]:
73
+ history["input_text"] = processed["truncated_text"]
74
+ history["was_truncated"] = True
75
+ msg = f"Input text was truncated to {max_input_length} characters."
76
+ logging.warning(msg)
77
+ history["WARNING"] = msg
78
+ else:
79
+ history["input_text"] = input_text
80
+ history["was_truncated"] = False
81
+
82
+ _summaries = summarize_via_tokenbatches(
83
+ history["input_text"],
84
+ model, tokenizer,
85
+ batch_length=token_batch_length,
86
+ **settings,
87
+ )
88
+ sum_text = [s['summary'][0] for s in _summaries]
89
+ sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in _summaries]
90
+
91
+
92
+ history["Input"] = input_text
93
+ history["Summary Text"] = "\n\t".join(sum_text)
94
+ history["Summary Scores"] = "\n".join(sum_scores)
95
+ html = ""
96
+ for name, item in history.items():
97
+ html += (
98
+ f"<h2>{name}:</h2><hr><b>{item}</b><br><br>"
99
+ if "summary" not in name.lower()
100
+ else f"<h2>{name}:</h2><hr><b>{item}</b>"
101
+ )
102
+
103
+ html += ""
104
+
105
+ return html
106
+
107
+ def load_examples(examples_dir='examples'):
108
+ src = _here / examples_dir
109
+ src.mkdir(exist_ok=True)
110
+ examples = [f for f in src.glob("*.txt")]
111
+ # load the examples into a list
112
+ text_examples = []
113
+ for example in examples:
114
+ with open(example, "r") as f:
115
+ text = f.read()
116
+ text_examples.append([text, 4, 2048, 0.7,3.5,3])
117
+
118
+ return text_examples
119
+
120
+ if __name__ == "__main__":
121
+
122
+ model, tokenizer = load_model_and_tokenizer('pszemraj/led-large-book-summary')
123
+ title = "Long-form text summarization with LED on the BookSumm dataset"
124
+ description = "This is a simple example of using the LED model to summarize a long-form text."
125
+
126
+ gr.Interface(
127
+ proc_submission,
128
+ inputs=[
129
+ gr.inputs.Textbox(lines=10, label="input text"),
130
+ gr.inputs.Slider(
131
+ minimum=4, maximum=10, label="num_beams", default=4, step=1
132
+ ),
133
+ gr.Dropdown(choices=[512, 1024, 2048, 4096], label="token_batch_length", default=2048),
134
+ gr.inputs.Slider(
135
+ minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
136
+ ),
137
+ gr.inputs.Slider(
138
+ minimum=1.0,
139
+ maximum=5.0,
140
+ label="repetition_penalty",
141
+ default=3.5,
142
+ step=0.1,
143
+ ),
144
+ gr.inputs.Slider(
145
+ minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
146
+ ),
147
+ ],
148
+ outputs="html",
149
+ examples_per_page=4,
150
+ title=title,
151
+ description=description,
152
+ examples=load_examples(),
153
+ ).launch(enable_queue=True, share=True)