# Captured from a Hugging Face Space page (Space status shown at capture time: "Runtime error").
import streamlit as st | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
from annotated_text import annotated_text | |
# Page-wide settings; this call is kept ahead of every other st.* call in the script.
st.set_page_config(
    layout="wide",
    page_title="Requirement Identifier",
)
@st.cache_resource  # the '@' was missing, making the cache a no-op statement
def get_pipleine():
    """Build the Dronology BERT text-classification pipeline.

    Decorated with ``st.cache_resource`` so the tokenizer and model are loaded
    once and reused across Streamlit reruns. Without this, they would be
    reloaded on every user interaction. The misspelled name is kept because
    the module calls it by that name.

    Returns:
        A ``transformers`` "text-classification" pipeline. Its tokenizer is
        configured with ``model_max_length=128``.
    """
    model_id = "SarmadBashir/dronology_bert_uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=128)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)


pipe = get_pipleine()
def get_prediction(input_text, pipe):
    """Classify ``input_text`` and return a readable label with a percentage.

    Args:
        input_text: Text to classify.
        pipe: A text-classification pipeline (for example, the one returned by
            ``get_pipleine``). It is called with ``truncation=True``.

    Returns:
        A ``(label, score)`` tuple, where:

        - ``label`` is ``'Information'`` for ``LABEL_0`` and ``'Requirement'``
          for ``LABEL_1``. Any other raw label is returned unchanged.
        - ``score`` is the top prediction's score as an integer percentage
          (0-100).
    """
    map_labels = {'LABEL_0': 'Information', 'LABEL_1': 'Requirement'}
    # Truncate to the tokenizer's model_max_length. Otherwise, long user text
    # can exceed the model's maximum sequence length and raise at inference time.
    output = pipe(input_text, truncation=True)
    top = output[0]
    label = map_labels.get(top['label'], top['label'])
    # Scale before rounding. int(round(s, 2) * 100) truncated float error,
    # e.g. 0.29 -> 28.999999999999996 -> 28.
    score = round(top['score'] * 100)
    return label, score
def _max_width_(max_width_px=1400):
    """Inject a CSS rule capping the width of the main block container.

    Args:
        max_width_px: Maximum container width in pixels. Defaults to 1400,
            the previously hard-coded value.

    NOTE(review): the ``.reportview-container`` selector may not match the
    DOM of newer Streamlit releases; verify before re-enabling the call.
    """
    max_width_str = f"max-width: {max_width_px}px;"
    st.markdown(
        f"""
<style>
.reportview-container .main .block-container{{
    {max_width_str}
}}
</style>
""",
        unsafe_allow_html=True,
    )
#_max_width_() | |
def show_updated_list(test_data):
    """Return the test sentences, each suffixed with its ground-truth tag.

    Args:
        test_data: DataFrame with a ``'STR.REQ'`` text column and a
            ``'class'`` label column.

    Returns:
        A list of strings. A row whose label equals 0 gets
        ``' (GT: Information)'``; every other row gets ``' (GT: Requirement)'``.
    """
    sentences = test_data['STR.REQ'].tolist()
    gt_classes = test_data['class'].tolist()
    return [
        sentence + (' (GT: Information)' if gt == 0 else ' (GT: Requirement)')
        for sentence, gt in zip(sentences, gt_classes)
    ]
# Header row: the title goes in the widest of three columns; the other two stay empty.
title_col, _, _ = st.columns([6, 1, 3])
with title_col:
    st.title("Requirement or not, that is the question!")
# "About" panel: paper reference, model description, replication link and
# funding acknowledgement with the funders' logos.
with st.expander("ℹ️ - About this app", expanded=True):
    st.write(
        """
- This app is a working demo for the paper: "Requirement or not, that is the question: A case from the railway industry".
- The app relies on a BERT model (trained on dronology dataset) for classification of the given text as Requirement or Information.
- The replication package [ReqORNot](https://github.com/a66as/REFSQ2023-ReqORNot) contains all the relevant code and information for replication of experiments.
"""
    )
    st.write("""
- This work is partially funded by [AIDOaRt (KDT)](https://sites.mdu.se/aidoart) and [SmartDelta (ITEA)](https://itea4.org/project/smartdelta.html) projects.
""")
    images = ['./pictures/smart_delta.jpeg', './pictures/aidoart.jpeg']
    # A narrow first column keeps the two logos next to each other.
    col1, col2 = st.columns([0.7, 7])
    # An explicit width already fixes the rendered size. The deprecated
    # use_column_width='never' argument was redundant and has been dropped.
    for logo_col, image_path in zip((col1, col2), images):
        with logo_col:
            st.image(image_path, width=100)

# Vertical spacing before the input form.
for _ in range(4):
    st.markdown("")
# Written text counts as real input only when it has at least this many
# whitespace-separated words. The same threshold decides whether the written
# text or the selected test row is classified; previously these used two
# different thresholds (2 and 3).
_MIN_WRITTEN_WORDS = 3

with st.form(key="my_form"):
    # Only the wide third column holds widgets; the others are spacers.
    _, _, c2, _ = st.columns([0.04, 0.01, 1, 0.04])
    with c2:
        # Recent Streamlit versions raise StreamlitAPIException for
        # st.text_area heights below 68 px, so 50 px is no longer accepted.
        doc = st.text_area(
            "✍️Write the text below",
            height=68,
        )
        st.write("***OR***")
        test_data = pd.read_csv('./data/test.csv')
        text = show_updated_list(test_data)
        text.insert(0, '-')  # '-' is the "no row selected" sentinel
        option = st.selectbox(
            'Select the row from test data',
            text)
        st.markdown("")
        submit_button = st.form_submit_button(label="✨ Get the Prediction!")

# Nothing below runs until the form has been submitted.
if not submit_button:
    st.stop()

has_written_text = len(doc.split()) >= _MIN_WRITTEN_WORDS

if option == '-' and not has_written_text:
    st.warning('Please Provide Valid Input!')
    st.stop()

st.markdown("")
st.markdown("")
st.markdown("### **Output**")

# Written text takes precedence; otherwise classify the selected row with
# its ground-truth tag (added by show_updated_list) stripped off.
if has_written_text:
    source, input_text = 'written', doc
else:
    gt_tag = '(GT: Requirement)' if '(GT: Requirement)' in option else '(GT: Information)'
    source, input_text = 'selected', option.replace(gt_tag, '').strip()

predicted_label, probability = get_prediction(input_text, pipe)
annotated_text(
    'The model classified',
    (source, "", "#8ef"), 'text as: ',
    (predicted_label, "", "#afa"), 'with a probability of',
    (str(probability), "", "#afa"), '%'
)