|
import torch |
|
import nltk |
|
import validators |
|
import streamlit as st |
|
from transformers import pipeline, T5Tokenizer |
|
|
|
|
|
from extractive_summarizer.model_processors import Summarizer |
|
from src.utils import clean_text, fetch_article_text |
|
from src.abstractive_summarizer import ( |
|
preprocess_text_for_abstractive_summarization, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Streamlit front end: takes raw text or a URL, then produces either an
    # extractive or an abstractive summary and renders it on the page.

    # Imported locally so this script body stays self-contained; AutoTokenizer
    # resolves the tokenizer class that matches the checkpoint name.
    from transformers import AutoTokenizer

    # ---- Page header and mode selection ----
    st.title("Text Summarizer 📝")
    summarize_type = st.sidebar.selectbox(
        "Summarization type", options=["Extractive", "Abstractive"]
    )

    # ---- Setup for the abstractive path ----
    # NOTE(review): "punkt" is presumably required for sentence splitting
    # inside the imported helpers -- confirm against src/.
    nltk.download("punkt")
    abs_tokenizer_name = "facebook/bart-large-cnn"
    abs_model_name = "facebook/bart-large-cnn"
    abs_max_length = 80
    abs_min_length = 30

    # ---- Input ----
    inp_text = st.text_input("Enter text or a url here")

    # validators.url() returns True on success and a falsy failure object
    # otherwise (never the literal False), so normalise it to a real bool.
    is_url = bool(validators.url(inp_text))
    if is_url:
        # Only the cleaned text is used below; the raw article text is
        # discarded. It is indexed and joined later, i.e. treated as a
        # sequence of strings.
        _, clean_txt = fetch_article_text(url=inp_text)
    else:
        clean_txt = clean_text(inp_text)

    # ---- Preview of what will be summarised ----
    with st.expander("View input text"):
        # For URLs only the first element of the cleaned sequence is shown.
        st.write(clean_txt[0] if is_url else clean_txt)
    summarize = st.button("Summarize")

    # ---- Summarisation, triggered by the button ----
    if summarize:
        if summarize_type == "Extractive":
            # The extractive model receives one string, so URL chunks are
            # joined back together with single spaces.
            text_to_summarize = " ".join(clean_txt) if is_url else clean_txt

            with st.spinner(
                text="Creating extractive summary. This might take a few seconds ..."
            ):
                ext_model = Summarizer()
                summarized_text = ext_model(text_to_summarize, num_sentences=6)

        elif summarize_type == "Abstractive":
            with st.spinner(
                text="Creating abstractive summary. This might take a few seconds ..."
            ):
                abs_summarizer = pipeline(
                    "summarization", model=abs_model_name, tokenizer=abs_tokenizer_name
                )

                if is_url:
                    # Text fetched from a URL is passed to the pipeline as is.
                    text_to_summarize = clean_txt
                else:
                    # Pasted text is pre-processed with the tokenizer that
                    # belongs to the checkpoint. The tokenizer is loaded here,
                    # on demand, because no other path uses it.
                    abs_tokenizer = AutoTokenizer.from_pretrained(abs_tokenizer_name)
                    text_to_summarize = preprocess_text_for_abstractive_summarization(
                        tokenizer=abs_tokenizer, text=clean_txt
                    )

                tmp_sum = abs_summarizer(
                    text_to_summarize,
                    max_length=abs_max_length,
                    min_length=abs_min_length,
                    do_sample=False,
                )

                # The pipeline output is iterated as dicts keyed by
                # "summary_text"; the partial summaries are joined into one.
                summarized_text = " ".join(summ["summary_text"] for summ in tmp_sum)

        # ---- Output ----
        st.subheader("Summarized text")
        st.info(summarized_text)
|
|