# E2E-QA-mining / app.py
# Last change by jian-mo: "Update app.py" (commit a3df176)
import html
import re

import seaborn as sns
import streamlit as st
from pandas import DataFrame
from transformers import pipeline

# For download buttons
# from functionforDownloadButtons import download_button
# Browser-tab settings for the app. Kept as the first Streamlit call, as the
# Streamlit `set_page_config` documentation requires.
# NOTE(review): "?" looks like an emoji icon that was lost in an encoding
# round-trip — confirm the intended page icon.
_page_config = dict(
    page_title="E2E QA MINING",
    page_icon="?",
)
st.set_page_config(**_page_config)
def _max_width_(max_width_px=1400):
    """Widen Streamlit's main content container by injecting a CSS rule.

    Args:
        max_width_px: Value (in pixels) for the CSS ``max-width`` of the main
            block container. Defaults to 1400, the previously hard-coded width.
    """
    # NOTE(review): `.reportview-container` is the container class used by
    # older Streamlit releases; newer releases may use other class names, in
    # which case this rule silently has no effect — verify against the
    # deployed Streamlit version.
    css = (
        "<style>\n"
        ".reportview-container .main .block-container{\n"
        f"    max-width: {max_width_px}px;\n"
        "}\n"
        "</style>"
    )
    # unsafe_allow_html is required so the <style> tag is injected rather
    # than shown as text; the CSS is built only from the numeric width above.
    st.markdown(css, unsafe_allow_html=True)
# Page header: widen the layout, show the title, and explain the app.
_max_width_()

# Only the first (left) column is used; the other two just reserve space so
# the title stays narrow.
c30, c31, c32 = st.columns([2.5, 1, 3])

with c30:
    # st.image("logo.png", width=400)
    # The key emoji was previously stored as mojibake ("πŸ”‘").
    st.title("🔑 E2E QA MINING")
    st.header("")

with st.expander("ℹ️ - About this app", expanded=True):
    st.write(
        """
- The *E2E QA MINING* app helps you mine question-answer pairs from a given context.
"""
    )
    st.markdown("")

st.markdown("")
# The pin emoji was previously stored as mojibake ("πŸ"Œ").
st.markdown("## **📌 Paste document **")

# Maximum number of words of the pasted text that are sent to the model.
MAX_WORDS = 512

# Text pre-filled in the text area so the app can be tried immediately.
_SAMPLE_DOC = '''The COVID-19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (COVID-19) caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The novel virus was first identified from an outbreak in Wuhan, China in December 2019. Attempts to contain it there failed, allowing the virus to spread worldwide. The World Health Organization (WHO) declared a Public Health Emergency of International Concern on 30 January 2020 and a pandemic on 11 March 2020. As of 2 April 2022, the pandemic had caused more than 489 million cases and 6.14 million deaths, making it one of the deadliest in history.
COVID-19 symptoms range from undetectable to deadly, but most commonly include fever, dry cough, and fatigue. Severe illness is more likely in elderly patients and those with certain underlying medical conditions. COVID‑19 transmits when people breathe in air contaminated by droplets and small airborne particles containing the virus. The risk of breathing these in is highest when people are in close proximity, but they can be inhaled over longer distances, particularly indoors. Transmission can also occur if contaminated fluids reach the eyes, nose or mouth, and, rarely, via contaminated surfaces. Infected persons are typically contagious for 10 days, and can spread the virus even if they do not develop symptoms. Mutations have produced many strains (variants) with varying degrees of infectivity and virulence. '''

# Streamlit reruns this whole script on every widget interaction. Cache the
# pipeline so the model is built once instead of on every rerun.
# `st.cache_resource` exists in newer Streamlit releases; older releases only
# provide `st.cache(allow_output_mutation=True)` for unhashable objects.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_generator():
    """Create the text2text-generation pipeline for the E2E-QA-Mining model."""
    return pipeline("text2text-generation", model="mojians/E2E-QA-Mining")


def _truncate_words(text, max_words):
    """Keep at most ``max_words`` words of ``text``.

    A "word" is a run of regex word characters (``\\w+``), the same unit the
    warning reports. The kept prefix retains its original spacing and
    punctuation (the text is cut right after the last kept word).

    Returns:
        A ``(truncated_text, total_word_count)`` tuple, where the count is
        that of the ORIGINAL text.
    """
    words = list(re.finditer(r"\w+", text))
    if len(words) <= max_words:
        return text, len(words)
    if max_words <= 0:
        return "", len(words)
    return text[: words[max_words - 1].end()], len(words)


generator = _load_generator()

with st.form(key="my_form"):
    ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])

    with c1:
        min_length = st.slider(
            "min_length",
            value=50,
            min_value=20,
            max_value=50,
            help="""The minimum value for the generated text""",
        )
        max_length = st.slider(
            "max length",
            value=500,
            min_value=50,
            max_value=500,
            help="""The maximum value for the generated text.""",
        )
        do_sample = st.checkbox(
            "do_sample",
            value=True,
            help="Tick this box to enable sampling",
        )

    with c2:
        doc = st.text_area(
            f"Paste your text below (max {MAX_WORDS} words)",
            height=510,
            value=_SAMPLE_DOC,
        )

        # Previously: `str + int` raised TypeError here, and `doc[:MAX_WORDS]`
        # cut the text to MAX_WORDS *characters* instead of words.
        doc, n_words = _truncate_words(doc, MAX_WORDS)
        if n_words > MAX_WORDS:
            st.warning(
                f"⚠️ Your text contains {n_words} words."
                f" Only the first {MAX_WORDS} words will be reviewed."
                " Stay tuned as increased allowance is coming! 😊"
            )

    submit_button = st.form_submit_button(label="✨ Get me the data!")

# Render nothing below the form until the user submits it.
if not submit_button:
    st.stop()
# Nothing to mine from an empty text area; skip the (slow) model call.
if not doc.strip():
    st.warning("Please paste some text to mine question-answer pairs from.")
    st.stop()

# Prompt layout fed to the model: the context followed by the task instruction.
results = generator(
    "context:" + doc + " generate questions and answers:",
    do_sample=do_sample,
    min_length=min_length,
    max_length=max_length,
)[0]["generated_text"]

st.markdown("## **🎈 Check results **")
st.header("")
# TODO: CSV / TXT / JSON download buttons (via a `download_button` helper)
# were stubbed out here; re-add them once that helper is available.
st.header("")

c1, c2, c3 = st.columns([1, 3, 1])

with c2:
    # The generated text uses "<sep>" between items. Each item is escaped
    # because it is rendered with unsafe_allow_html and derives from user
    # input; blank segments (e.g. after a trailing "<sep>") are skipped so
    # they do not show up as empty bullets.
    items = [
        html.escape(item.strip())
        for item in results.split("<sep>")
        if item.strip()
    ]
    text_output = st.empty()
    text_output.markdown(
        " ".join(["<ul>", " ".join(f"<li>{item}</li>" for item in items), "</ul>"]),
        unsafe_allow_html=True,
    )