update for T5
- app.py +2 -2
- src/abstractive_summarizer.py +3 -5
app.py CHANGED

@@ -2,8 +2,7 @@ import torch
 import nltk
 import validators
 import streamlit as st
-from
-from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 # local modules
 from extractive_summarizer.model_processors import Summarizer

@@ -68,6 +67,7 @@ if __name__ == "__main__":
         text_to_summarize = clean_txt
         abs_tokenizer, abs_model = load_abs_model()
         if not is_url:
+            # list of chunks
             text_to_summarize = preprocess_text_for_abstractive_summarization(
                 tokenizer=abs_tokenizer, text=clean_txt
             )
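With `pipeline` dropped, the app imports the T5 classes directly and `load_abs_model()` is expected to return a `(tokenizer, model)` pair. A minimal sketch of what that loader could look like under these imports; the `t5-base` checkpoint name is an assumption for illustration, not part of this commit:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

def load_abs_model():
    # Checkpoint name is assumed; the diff does not show which one the app pins.
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    return tokenizer, model
```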
src/abstractive_summarizer.py CHANGED

@@ -5,13 +5,11 @@ from transformers import T5Tokenizer
 
 def abstractive_summarizer(tokenizer, model, text):
     # inputs to the model
-    inputs = [
-        tokenizer.encode(f"summarize: {chunk}", return_tensors="pt") for chunk in text
-    ]
+    inputs = [tokenizer(f"summarize: {chunk}", return_tensors="pt") for chunk in text]
     abs_summarized_text = []
     for input in inputs:
-        output = model.generate(
-        tmp_sum = tokenizer.decode(
+        output = model.generate(input["input_ids"])
+        tmp_sum = tokenizer.decode(output[0], skip_special_tokens=True)
         abs_summarized_text.append(tmp_sum)
 
     abs_summarized_text = " ".join([summ for summ in abs_summarized_text])
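The key difference here: `tokenizer.encode(...)` returns a bare tensor of token ids, while calling the tokenizer directly returns a `BatchEncoding` dict, which is why the loop now indexes `input["input_ids"]` before passing it to `generate`. A self-contained sketch of the updated flow; the checkpoint name and sample chunks are placeholders, and the chunk list stands in for the output of `preprocess_text_for_abstractive_summarization`:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Checkpoint is a placeholder; the commit does not specify one here.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Stand-in for the chunks produced by the app's preprocessing step.
chunks = ["First part of a long article ...", "Second part ..."]

summaries = []
for chunk in chunks:
    inputs = tokenizer(f"summarize: {chunk}", return_tensors="pt")
    output = model.generate(inputs["input_ids"])  # default generation settings
    summaries.append(tokenizer.decode(output[0], skip_special_tokens=True))

print(" ".join(summaries))
```

Note that `model.generate(input["input_ids"])` relies on the library's default generation settings (greedy decoding, default length limits), since the diff passes no other arguments.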