Update summarize.py
Browse files- summarize.py +8 -2
summarize.py
CHANGED
@@ -116,15 +116,21 @@ def summarize_via_tokenbatches(
|
|
116 |
pbar = tqdm(total=len(in_id_arr))
|
117 |
|
118 |
for _id, _mask in zip(in_id_arr, att_arr):
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
120 |
result = summarize(
|
121 |
ids=_id,
|
122 |
mask=_mask,
|
123 |
model=model,
|
|
|
124 |
tokenizer=tokenizer,
|
125 |
**kwargs,
|
126 |
)
|
127 |
-
rate = round(float(len()), 3)
|
128 |
_sum = {
|
129 |
"input_tokens": _id,
|
130 |
"summary": result,
|
|
|
116 |
pbar = tqdm(total=len(in_id_arr))
|
117 |
|
118 |
for _id, _mask in zip(in_id_arr, att_arr):
|
119 |
+
|
120 |
+
if model=='Blaise-g/led_pubmed_sumpubmed_1' or model=='Blaise-g/led_large_sumpbumed_scitldr':
|
121 |
+
model_arch = 'LED'
|
122 |
+
else:
|
123 |
+
model_arch = 'LongT5'
|
124 |
+
|
125 |
result = summarize(
|
126 |
ids=_id,
|
127 |
mask=_mask,
|
128 |
model=model,
|
129 |
+
model_arch=model_arch,
|
130 |
tokenizer=tokenizer,
|
131 |
**kwargs,
|
132 |
)
|
133 |
+
rate = round(float((len(input_text)-len(result))/len(input_text)), 3)
|
134 |
_sum = {
|
135 |
"input_tokens": _id,
|
136 |
"summary": result,
|