Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
✨ mwe working aggregation
Browse filesSigned-off-by: peter szemraj <peterszemraj@gmail.com>
- aggregate.py +158 -67
- app.py +91 -10
aggregate.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
-
|
2 |
import logging
|
3 |
import time
|
4 |
|
5 |
import torch
|
6 |
from transformers import GenerationConfig, pipeline
|
7 |
|
|
|
|
|
8 |
# Setting up logging
|
9 |
logging.basicConfig(
|
10 |
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
@@ -12,94 +14,182 @@ logging.basicConfig(
|
|
12 |
|
13 |
|
14 |
class BatchAggregator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def __init__(
|
16 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
17 |
):
|
|
|
|
|
18 |
self.logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
self.model_name = model_name
|
20 |
-
self.
|
21 |
-
self.
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
|
|
|
|
|
|
28 |
try:
|
29 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
except Exception as e:
|
31 |
-
self.logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
try:
|
33 |
-
self.aggregator.model
|
34 |
-
|
35 |
-
)
|
36 |
except Exception as e:
|
37 |
-
self.logger.warning(
|
38 |
-
|
39 |
-
|
40 |
-
self.
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
)
|
54 |
|
55 |
-
if "bart" in model_name.lower():
|
56 |
-
self.logger.info("Using BART model, updating generation config")
|
57 |
-
upd = {
|
58 |
-
"num_beams": 8,
|
59 |
-
"repetition_penalty": 1.3,
|
60 |
-
"length_penalty": 1.0,
|
61 |
-
"_from_model_config": False,
|
62 |
-
"max_new_tokens": 256,
|
63 |
-
"min_new_tokens": 32,
|
64 |
-
"no_repeat_ngram_size": 3,
|
65 |
-
"encoder_no_repeat_ngram_size": 6,
|
66 |
-
}
|
67 |
-
self.aggregator.model.generation_config.update(**upd)
|
68 |
-
if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
|
69 |
-
self.logger.info("Updating generation config with defaults")
|
70 |
-
self.update_generation_config()
|
71 |
self.logger.info(self.aggregator.model.generation_config.to_json_string())
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def update_generation_config(self, **kwargs):
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
def infer_aggregate(
|
92 |
self,
|
93 |
text_list: list,
|
94 |
-
instruction: str =
|
95 |
**kwargs,
|
96 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
joined_text = "\n".join(text_list)
|
98 |
prompt = f"{instruction}\n\n{joined_text}\n"
|
99 |
if kwargs:
|
100 |
self.update_generation_config(**kwargs)
|
101 |
st = time.perf_counter()
|
102 |
-
self.logger.info(f"
|
103 |
result = self.aggregator(
|
104 |
prompt,
|
105 |
generation_config=self.aggregator.model.generation_config,
|
@@ -110,7 +200,8 @@ class BatchAggregator:
|
|
110 |
)
|
111 |
return result
|
112 |
|
113 |
-
def count_tokens(self, text: str):
|
|
|
114 |
return (
|
115 |
len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
|
116 |
if text
|
|
|
1 |
+
import pprint as pp
|
2 |
import logging
|
3 |
import time
|
4 |
|
5 |
import torch
|
6 |
from transformers import GenerationConfig, pipeline
|
7 |
|
8 |
+
from utils import compare_model_size
|
9 |
+
|
10 |
# Setting up logging
|
11 |
logging.basicConfig(
|
12 |
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
|
|
14 |
|
15 |
|
16 |
class BatchAggregator:
|
17 |
+
CONFIGURED_MODELS = [
|
18 |
+
"pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
19 |
+
] # TODO: Add models here
|
20 |
+
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
21 |
+
GENERIC_CONFIG = GenerationConfig(
|
22 |
+
num_beams=8,
|
23 |
+
early_stopping=True,
|
24 |
+
do_sample=False,
|
25 |
+
min_new_tokens=32,
|
26 |
+
max_new_tokens=256,
|
27 |
+
repetition_penalty=1.1,
|
28 |
+
length_penalty=1.4,
|
29 |
+
no_repeat_ngram_size=4,
|
30 |
+
encoder_no_repeat_ngram_size=5,
|
31 |
+
)
|
32 |
+
|
33 |
def __init__(
|
34 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
35 |
):
|
36 |
+
self.device = None
|
37 |
+
self.is_compiled = False
|
38 |
self.logger = logging.getLogger(__name__)
|
39 |
+
self.init_model(model_name)
|
40 |
+
|
41 |
+
def init_model(self, model_name: str) -> None:
|
42 |
+
"""
|
43 |
+
Initialize the model.
|
44 |
+
|
45 |
+
:param model_name: The name of the model to use.
|
46 |
+
"""
|
47 |
+
# Free up memory
|
48 |
+
if torch.cuda.is_available():
|
49 |
+
torch.cuda.empty_cache()
|
50 |
+
|
51 |
+
self.logger.info(f"Setting model to {model_name}")
|
52 |
self.model_name = model_name
|
53 |
+
self.aggregator = self._create_pipeline(model_name)
|
54 |
+
self._configure_model()
|
55 |
+
# update the generation config with the specific tokenizer
|
56 |
+
tokenizer_params = {
|
57 |
+
"decoder_start_token_id": 0
|
58 |
+
if "t5" in model_name.lower()
|
59 |
+
else self.aggregator.tokenizer.eos_token_id,
|
60 |
+
"eos_token_id": 1
|
61 |
+
if "t5" in model_name.lower()
|
62 |
+
else self.aggregator.tokenizer.eos_token_id,
|
63 |
+
"pad_token_id": 0
|
64 |
+
if "t5" in model_name.lower()
|
65 |
+
else self.aggregator.tokenizer.pad_token_id,
|
66 |
+
}
|
67 |
+
self.update_generation_config(**tokenizer_params)
|
68 |
+
|
69 |
+
def _create_pipeline(
|
70 |
+
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
71 |
+
) -> pipeline:
|
72 |
+
"""
|
73 |
+
_create_pipeline creates a pipeline for the model.
|
74 |
+
|
75 |
+
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
76 |
+
:return pipeline: the pipeline for the model
|
77 |
|
78 |
+
:raises Exception: if the pipeline cannot be created
|
79 |
+
"""
|
80 |
+
self.device = 0 if torch.cuda.is_available() else -1
|
81 |
try:
|
82 |
+
self.logger.info(
|
83 |
+
f"Creating pipeline with model {model_name} on device {self.device}"
|
84 |
+
)
|
85 |
+
return pipeline(
|
86 |
+
"text2text-generation",
|
87 |
+
model_name,
|
88 |
+
device=self.device,
|
89 |
+
torch_dtype=torch.float32,
|
90 |
+
)
|
91 |
except Exception as e:
|
92 |
+
self.logger.error(f"Failed to create pipeline: {e}")
|
93 |
+
raise
|
94 |
+
|
95 |
+
def _configure_model(self):
|
96 |
+
"""
|
97 |
+
Configure the model for generation.
|
98 |
+
"""
|
99 |
try:
|
100 |
+
self.aggregator.model = torch.compile(self.aggregator.model)
|
101 |
+
self.is_compiled = True
|
|
|
102 |
except Exception as e:
|
103 |
+
self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
|
104 |
+
|
105 |
+
if self.model_name not in self.CONFIGURED_MODELS:
|
106 |
+
self.logger.info("Setting generation config to general defaults")
|
107 |
+
self._set_default_generation_config()
|
108 |
+
else:
|
109 |
+
try:
|
110 |
+
self.logger.info("Loading generation config from hub")
|
111 |
+
self.aggregator.model.generation_config = (
|
112 |
+
GenerationConfig.from_pretrained(self.model_name)
|
113 |
+
)
|
114 |
+
except Exception as e:
|
115 |
+
self.logger.warning(
|
116 |
+
f"Could not load generation config, using defaults: {e}"
|
117 |
+
)
|
118 |
+
self._set_default_generation_config()
|
|
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
self.logger.info(self.aggregator.model.generation_config.to_json_string())
|
121 |
|
122 |
+
def _set_default_generation_config(self):
|
123 |
+
"""
|
124 |
+
Set the default generation configuration for the model.
|
125 |
+
"""
|
126 |
+
self.aggregator.model.generation_config = self.GENERIC_CONFIG
|
127 |
+
|
128 |
+
if "bart" in self.model_name.lower():
|
129 |
+
self.logger.info("Using BART model, updating generation config")
|
130 |
+
upd = {
|
131 |
+
"num_beams": 8,
|
132 |
+
"repetition_penalty": 1.3,
|
133 |
+
"length_penalty": 1.0,
|
134 |
+
"_from_model_config": False,
|
135 |
+
"max_new_tokens": 256,
|
136 |
+
"min_new_tokens": 32,
|
137 |
+
"no_repeat_ngram_size": 3,
|
138 |
+
"encoder_no_repeat_ngram_size": 6,
|
139 |
+
} # TODO: clean up
|
140 |
+
self.aggregator.model.generation_config.update(**upd)
|
141 |
+
|
142 |
+
if (
|
143 |
+
"large"
|
144 |
+
or "xl" in self.model_name.lower()
|
145 |
+
or compare_model_size(self.model_name, 500)
|
146 |
+
):
|
147 |
+
upd = {"num_beams": 4}
|
148 |
+
self.update_generation_config(**upd)
|
149 |
+
|
150 |
def update_generation_config(self, **kwargs):
|
151 |
+
"""
|
152 |
+
Update the generation configuration with the specified parameters.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
**kwargs: The parameters to update in the generation configuration.
|
156 |
+
"""
|
157 |
+
self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
|
158 |
+
|
159 |
+
self.aggregator.model.generation_config.update(**kwargs)
|
160 |
+
|
161 |
+
def update_loglevel(self, level: str = "INFO"):
|
162 |
+
"""
|
163 |
+
Update the log level.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
level (str): The log level to set. Defaults to "INFO".
|
167 |
+
"""
|
168 |
+
self.logger.setLevel(level)
|
169 |
+
|
170 |
def infer_aggregate(
|
171 |
self,
|
172 |
text_list: list,
|
173 |
+
instruction: str = DEFAULT_INSTRUCTION,
|
174 |
**kwargs,
|
175 |
+
) -> str:
|
176 |
+
f"""
|
177 |
+
Generate a summary of the specified texts.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
text_list (list): The texts to summarize.
|
181 |
+
instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}.
|
182 |
+
**kwargs: Additional parameters to update in the generation configuration.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
The generated summary.
|
186 |
+
"""
|
187 |
joined_text = "\n".join(text_list)
|
188 |
prompt = f"{instruction}\n\n{joined_text}\n"
|
189 |
if kwargs:
|
190 |
self.update_generation_config(**kwargs)
|
191 |
st = time.perf_counter()
|
192 |
+
self.logger.info(f"inference on {len(text_list)} texts ...")
|
193 |
result = self.aggregator(
|
194 |
prompt,
|
195 |
generation_config=self.aggregator.model.generation_config,
|
|
|
200 |
)
|
201 |
return result
|
202 |
|
203 |
+
def count_tokens(self, text: str) -> int:
|
204 |
+
"""count the number of tokens in a text"""
|
205 |
return (
|
206 |
len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
|
207 |
if text
|
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
app.py - the main module for the gradio app
|
3 |
|
4 |
Usage:
|
5 |
python app.py
|
@@ -19,6 +19,7 @@ import random
|
|
19 |
import re
|
20 |
import time
|
21 |
from pathlib import Path
|
|
|
22 |
|
23 |
os.environ["USE_TORCH"] = "1"
|
24 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
@@ -31,16 +32,18 @@ logging.basicConfig(
|
|
31 |
import gradio as gr
|
32 |
import nltk
|
33 |
import torch
|
|
|
34 |
from cleantext import clean
|
35 |
from doctr.models import ocr_predictor
|
36 |
-
|
37 |
from pdf2text import convert_PDF_to_Text
|
38 |
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
|
39 |
from utils import (
|
|
|
40 |
load_example_filenames,
|
41 |
saves_summary,
|
42 |
textlist2html,
|
43 |
truncate_word_count,
|
|
|
44 |
)
|
45 |
|
46 |
_here = Path(__file__).parent
|
@@ -57,10 +60,76 @@ MODEL_OPTIONS = [
|
|
57 |
"pszemraj/pegasus-x-large-book-summary",
|
58 |
] # models users can choose from
|
59 |
|
|
|
|
|
60 |
# if duplicating space,, uncomment this line to adjust the max words
|
61 |
# os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
|
62 |
# os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
def predict(
|
66 |
input_text: str,
|
@@ -128,6 +197,7 @@ def proc_submission(
|
|
128 |
str in HTML format, string of the summary, str of score
|
129 |
"""
|
130 |
|
|
|
131 |
settings = {
|
132 |
"length_penalty": float(length_penalty),
|
133 |
"repetition_penalty": float(repetition_penalty),
|
@@ -208,7 +278,6 @@ def proc_submission(
|
|
208 |
# save to file
|
209 |
settings["model_name"] = model_name
|
210 |
saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
|
211 |
-
|
212 |
return html, full_summary, scores_out, saved_file
|
213 |
|
214 |
|
@@ -361,7 +430,7 @@ if __name__ == "__main__":
|
|
361 |
summarize_button = gr.Button(
|
362 |
"Summarize!",
|
363 |
variant="primary",
|
364 |
-
)
|
365 |
output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
|
366 |
with gr.Column():
|
367 |
gr.Markdown("#### Results & Scores")
|
@@ -384,11 +453,19 @@ if __name__ == "__main__":
|
|
384 |
label="Summary Scores",
|
385 |
placeholder="Summary scores will appear here",
|
386 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
-
gr.Markdown("#### **Summary Output**")
|
389 |
-
summary_text = gr.HTML(
|
390 |
-
label="Summary", value="<i>Summary will appear here!</i>"
|
391 |
-
)
|
392 |
gr.Markdown("---")
|
393 |
with gr.Column():
|
394 |
gr.Markdown("### Advanced Settings")
|
@@ -456,5 +533,9 @@ if __name__ == "__main__":
|
|
456 |
],
|
457 |
outputs=[output_text, summary_text, summary_scores, text_file],
|
458 |
)
|
459 |
-
|
460 |
-
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
app.py - the main module for the gradio app for summarization
|
3 |
|
4 |
Usage:
|
5 |
python app.py
|
|
|
19 |
import re
|
20 |
import time
|
21 |
from pathlib import Path
|
22 |
+
import pprint as pp
|
23 |
|
24 |
os.environ["USE_TORCH"] = "1"
|
25 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
32 |
import gradio as gr
|
33 |
import nltk
|
34 |
import torch
|
35 |
+
from aggregate import BatchAggregator
|
36 |
from cleantext import clean
|
37 |
from doctr.models import ocr_predictor
|
|
|
38 |
from pdf2text import convert_PDF_to_Text
|
39 |
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
|
40 |
from utils import (
|
41 |
+
extract_batches,
|
42 |
load_example_filenames,
|
43 |
saves_summary,
|
44 |
textlist2html,
|
45 |
truncate_word_count,
|
46 |
+
remove_stagnant_files,
|
47 |
)
|
48 |
|
49 |
_here = Path(__file__).parent
|
|
|
60 |
"pszemraj/pegasus-x-large-book-summary",
|
61 |
] # models users can choose from
|
62 |
|
63 |
+
SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
|
64 |
+
|
65 |
# if duplicating space,, uncomment this line to adjust the max words
|
66 |
# os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
|
67 |
# os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
|
68 |
|
69 |
+
aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")
|
70 |
+
|
71 |
+
|
72 |
+
def aggregate_text(
|
73 |
+
summary_text: str,
|
74 |
+
text_file: gr.inputs.File = None,
|
75 |
+
):
|
76 |
+
"""
|
77 |
+
Aggregate the text from the batches.
|
78 |
+
|
79 |
+
NOTE: you should probably include passing the BatchAggregator object as a parameter if using this code
|
80 |
+
outside of this file.
|
81 |
+
:param batches_html: The batches to aggregate, in html format
|
82 |
+
"""
|
83 |
+
if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
|
84 |
+
logging.error("No text provided. Make sure a summary has been generated first.")
|
85 |
+
return "Error: No text provided. Make sure a summary has been generated first."
|
86 |
+
|
87 |
+
try:
|
88 |
+
extracted_batches = extract_batches(summary_text)
|
89 |
+
except Exception as e:
|
90 |
+
logging.info(summary_text)
|
91 |
+
logging.info(f"the batches html is: {type(summary_text)}")
|
92 |
+
return f"Error: unable to extract batches - check input: {e}"
|
93 |
+
if not extracted_batches:
|
94 |
+
logging.error("unable to extract batches - check input")
|
95 |
+
return "Error: unable to extract batches - check input"
|
96 |
+
|
97 |
+
out_path = None
|
98 |
+
if text_file is not None:
|
99 |
+
out_path = text_file.name # assuming name attribute stores the file path
|
100 |
+
|
101 |
+
content_batches = [batch["content"] for batch in extracted_batches]
|
102 |
+
full_summary = aggregator.infer_aggregate(content_batches)
|
103 |
+
|
104 |
+
# if a path that exists is provided, save the summary with markdown formatting
|
105 |
+
if out_path:
|
106 |
+
out_path = Path(out_path)
|
107 |
+
|
108 |
+
try:
|
109 |
+
with open(out_path, "a", encoding="utf-8") as f:
|
110 |
+
f.write("\n\n### Aggregate Summary\n\n")
|
111 |
+
f.write(
|
112 |
+
"- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
|
113 |
+
)
|
114 |
+
f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
|
115 |
+
f.write(f"{full_summary}\n\n")
|
116 |
+
logging.info(f"Updated {out_path} with aggregate summary")
|
117 |
+
except Exception as e:
|
118 |
+
logging.error(f"unable to update {out_path} with aggregate summary: {e}")
|
119 |
+
|
120 |
+
full_summary_html = f"""
|
121 |
+
<div style="
|
122 |
+
margin-bottom: 20px;
|
123 |
+
font-size: 18px;
|
124 |
+
line-height: 1.5em;
|
125 |
+
color: #333;
|
126 |
+
">
|
127 |
+
<h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
|
128 |
+
<p style="white-space: pre-line;">{full_summary}</p>
|
129 |
+
</div>
|
130 |
+
"""
|
131 |
+
return full_summary_html
|
132 |
+
|
133 |
|
134 |
def predict(
|
135 |
input_text: str,
|
|
|
197 |
str in HTML format, string of the summary, str of score
|
198 |
"""
|
199 |
|
200 |
+
remove_stagnant_files() # clean up old files
|
201 |
settings = {
|
202 |
"length_penalty": float(length_penalty),
|
203 |
"repetition_penalty": float(repetition_penalty),
|
|
|
278 |
# save to file
|
279 |
settings["model_name"] = model_name
|
280 |
saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
|
|
|
281 |
return html, full_summary, scores_out, saved_file
|
282 |
|
283 |
|
|
|
430 |
summarize_button = gr.Button(
|
431 |
"Summarize!",
|
432 |
variant="primary",
|
433 |
+
) # TODO: collapse button to be on same line as something else
|
434 |
output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
|
435 |
with gr.Column():
|
436 |
gr.Markdown("#### Results & Scores")
|
|
|
453 |
label="Summary Scores",
|
454 |
placeholder="Summary scores will appear here",
|
455 |
)
|
456 |
+
with gr.Column():
|
457 |
+
gr.Markdown("#### **Summary Output**")
|
458 |
+
summary_text = gr.HTML(
|
459 |
+
label="Summary", value="<i>Summary will appear here!</i>"
|
460 |
+
)
|
461 |
+
with gr.Column():
|
462 |
+
gr.Markdown("##### **Aggregate Summary Batches**")
|
463 |
+
aggregate_button = gr.Button(
|
464 |
+
"Aggregate!",
|
465 |
+
variant="primary",
|
466 |
+
) # TODO: collapse button to be on same line as something else
|
467 |
+
aggregated_summary = gr.HTML(label="Aggregate Summary", value="")
|
468 |
|
|
|
|
|
|
|
|
|
469 |
gr.Markdown("---")
|
470 |
with gr.Column():
|
471 |
gr.Markdown("### Advanced Settings")
|
|
|
533 |
],
|
534 |
outputs=[output_text, summary_text, summary_scores, text_file],
|
535 |
)
|
536 |
+
aggregate_button.click(
|
537 |
+
fn=aggregate_text,
|
538 |
+
inputs=[summary_text, text_file],
|
539 |
+
outputs=[aggregated_summary],
|
540 |
+
)
|
541 |
+
demo.launch(enable_queue=True, share=True)
|