# Page header residue: author SarmadBashir, commit "Update app.py" (4bd96e0)
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from annotated_text import annotated_text
# Set the page title and select the "wide" layout for the whole app.
st.set_page_config(
    page_title="Requirement Identifier",
    layout="wide",
    # page_icon="🎈",
)
# `st.cache_resource` must be applied as a decorator; written as a bare
# expression it has no effect and the model is rebuilt on every script run.
@st.cache_resource
def get_pipleine():
    """Build the text-classification pipeline used by the app.

    Loads the tokenizer (with ``model_max_length=128``) and the
    sequence-classification model from the
    ``SarmadBashir/dronology_bert_uncased`` checkpoint and wraps both in a
    ``"text-classification"`` pipeline.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    checkpoint = "SarmadBashir/dronology_bert_uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, model_max_length=128)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)


# Module-level pipeline handle passed to get_prediction() further down.
pipe = get_pipleine()
def get_prediction(input_text, pipe):
    """Classify ``input_text`` and return a readable label with a percentage.

    Args:
        input_text: Text handed to ``pipe``.
        pipe: Callable returning a sequence whose first item is a mapping
            with ``'label'`` and ``'score'`` keys.

    Returns:
        A ``(label, score)`` tuple. ``label`` is ``'Information'`` for
        ``LABEL_0`` and ``'Requirement'`` for ``LABEL_1``; any other raw label
        is returned unchanged. ``score`` is the score scaled to 0-100 and
        rounded to the nearest integer.
    """
    map_labels = {'LABEL_0': 'Information', 'LABEL_1': 'Requirement'}
    top = pipe(input_text)[0]
    raw_label = top['label']
    # Fall back to the raw label so an unmapped class never yields None.
    label = map_labels.get(raw_label, raw_label)
    # Scale first, then round: int(round(s, 2) * 100) truncates values such as
    # 0.57 (0.57 * 100 == 56.99999999999999) down to 56.
    score = int(round(top['score'] * 100))
    return label, score
def _max_width_():
    """Emit a ``<style>`` element that sets ``max-width: 1400px`` on the
    ``.reportview-container .main .block-container`` selector."""
    max_width_str = "max-width: 1400px;"
    # Assembled piece by piece; the resulting markup is the same text the
    # triple-quoted template produced.
    css = (
        "\n<style>\n"
        ".reportview-container .main .block-container{\n"
        f"{max_width_str}\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(css, unsafe_allow_html=True)
#_max_width_()
def show_updated_list(test_data):
    """Return the ``'STR.REQ'`` texts, each suffixed with its ground-truth tag.

    Args:
        test_data: Object indexable by ``'STR.REQ'`` and ``'class'`` whose
            columns expose ``tolist()`` (a pandas DataFrame in this app).

    Returns:
        A list of strings: each text followed by ``' (GT: Information)'`` when
        its class equals 0, otherwise by ``' (GT: Requirement)'``.
    """
    sentences = test_data['STR.REQ'].tolist()
    classes = test_data['class'].tolist()
    return [
        sentence + (' (GT: Information)' if cls == 0 else ' (GT: Requirement)')
        for sentence, cls in zip(sentences, classes)
    ]
# Header row built from a three-column layout with width spec [6, 1, 3];
# c31 and c32 are not referenced again in the code shown here.
c30, c31, c32 = st.columns([6, 1, 3])
with c30:
# st.image("logo.png", width=400)
st.title("Requirement or not, that is the question!")
#st.header("")
# "About this app" expander, created in the expanded state.
with st.expander("ℹ️ - About this app", expanded=True):
st.write(
"""
- This app is a working demo for the paper: "Requirement or not, that is the question: A case from the railway industry".
- The app relies on BERT model (trained on dronology dataset) for classification of given text as Requirement or Information.
- The replication package [ReqORNot](https://github.com/a66as/REFSQ2023-ReqORNot) contains all the relevant code and information for replication of experiements.
"""
)
# Funding acknowledgement with links to the two projects.
st.write("""
- This work is partially funded by [AIDOaRt (KDT)](https://sites.mdu.se/aidoart) and [SmartDelta (ITEA)](https://itea4.org/project/smartdelta.html) projects.
""")
# Relative paths of the two images shown below — presumably the SmartDelta and
# AIDOaRt logos, judging by the file names; confirm against ./pictures.
images= ['./pictures/smart_delta.jpeg', './pictures/aidoart.jpeg' ]
#st.image(images, caption=None, width=130, use_column_width=False, clamp=False, channels="RGB", output_format="auto")
#st.image('./pictures/smart_delta.jpeg', caption=None, width=130, use_column_width=None, clamp=False, channels="RGB", output_format="auto")
# Two columns with width spec [0.7, 7], one image in each, both at width=100.
col1,col2 = st.columns([0.7,7])
with col1:
st.image(images[0],width=100,use_column_width='never')
with col2:
st.image(images[1],width=100,use_column_width='never')
# Empty markdown elements, apparently used as vertical spacing.
st.markdown("")
st.markdown("")
st.markdown("")
st.markdown("")
# Input form: the user either types text or picks a row from the test set,
# then presses the submit button.
with st.form(key="my_form"):
# Four columns are created but only c2 is entered with `with` below. `ce` is
# bound twice in this unpacking, so it ends up holding the last column.
ce, c1, c2, ce = st.columns([0.04, 0.01, 1, 0.04])
with c2:
# Free-text input; its value is read as `doc` after submission.
doc = st.text_area(
"✍️Write the text below",
height=50,
)
st.write("***OR***")
# Load the test set from disk; show_updated_list() reads its 'STR.REQ' and
# 'class' columns to build the labelled choices.
test_data = pd.read_csv('./data/test.csv')
#test_data = test_data.sample(frac=1)
text = show_updated_list(test_data)
# '-' is put first as the "nothing selected" placeholder that the
# post-submit checks compare `option` against.
text.insert(0, '-')
option = st.selectbox(
'Select the row from test data',
text)
#print(option)
st.markdown("")
# Truthiness of this value gates everything that follows the form.
submit_button = st.form_submit_button(label="✨ Get the Prediction!")
# Nothing below runs until the form has been submitted.
if not submit_button:
    st.stop()

# Reject the submission when no row was selected and the typed text has at
# most two whitespace-separated words.
if option == '-' and len(doc.split()) <= 2:
    st.warning('Please Provide Valid Input!')
    st.stop()

for heading in ("", "", "### **Output**"):
    st.markdown(heading)

# A selected row is classified unless more than three words were typed, in
# which case the typed text is classified instead.
use_selected_row = option != '-' and len(doc.split()) <= 3
if use_selected_row:
    # Strip the ground-truth tag that show_updated_list() appended.
    marker = '(GT: Requirement)' if '(GT: Requirement)' in option else '(GT: Information)'
    source_kind = 'selected'
    model_input = option.replace(marker, '').strip()
else:
    source_kind = 'written'
    model_input = doc

predicted_label, probability = get_prediction(model_input, pipe)
annotated_text(
    'The model classified',
    (source_kind, "", "#8ef"), 'text as: ',
    (predicted_label, "", "#afa"), 'with a probability of',
    (str(probability), "", "#afa"), '%',
)