# Captured from a Hugging Face Spaces page; Space status at capture time: Build error
""" | |
@author:jishnuprakash | |
""" | |
import nltk | |
nltk.download('stopwords') | |
import os | |
import torch | |
import spacy | |
import utils as ut | |
import streamlit as st | |
import pandas as pd | |
import pandas as pd | |
import numpy as np | |
from spacy import displacy | |
from tqdm.auto import tqdm | |
from datasets import load_dataset | |
from transformers import AutoTokenizer | |
from pytorch_lightning.metrics.functional import accuracy | |
# --- Page configuration and header ------------------------------------------
st.set_page_config(page_title='Classification - BERT', layout='wide', page_icon=':computer:')
# NOTE(review): recent Streamlit releases no longer accept this config option
# (setting it raises), and nothing visible in this script calls st.pyplot —
# confirm against the pinned Streamlit version before keeping it.
st.set_option('deprecation.showPyplotGlobalUse', False)

# Page title (raw HTML, rendered unescaped).
st.markdown("<h1 style='text-align: center; color: black;'>Multi-label classification using BERT Transformers</h1>", unsafe_allow_html=True)
# Author / portfolio / source-code links (raw HTML, rendered unescaped, so the
# attribute quoting must stay balanced).
st.markdown("<div style='text-align: center'> Author: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
st.text('')

# Collapsible description of what the app does.
expander = st.expander("View Description")
expander.write("""This is a user interface to view and interact with
results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
Try inputting a text below and see the model predictions. You can also extract the location
and Date entities from the text using the checkbox.\\
Below, you can do the same on test data.\\
Please find the test data here https://huggingface.co/datasets/lex_glue """)
def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit script reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would rebuild the model, tokenizer, dataset and spaCy
    pipeline each time. ``st.cache_resource`` is used when this Streamlit
    version provides it; otherwise the legacy
    ``st.cache(allow_output_mutation=True)`` decorator is used.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model():
    """Load everything the app needs for inference.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)`` where

        * ``trained_model`` is the ``ut.LexGlueTagger`` restored from
          ``ut.check_filename + '.ckpt'``, set to eval mode and frozen;
        * ``tokenizer`` is the ``AutoTokenizer`` for ``ut.bert_model``;
        * ``test`` is the ``lex_glue`` / ``ecthr_a`` test split converted to a
          DataFrame and passed through ``ut.preprocess_data``;
        * ``NER`` is the spaCy ``en_core_web_sm`` pipeline.
    """
    trained_model = ut.LexGlueTagger.load_from_checkpoint(
        ut.check_filename + '.ckpt', num_classes=ut.num_classes)
    # Initialise the BERT tokenizer matching the fine-tuned model.
    tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)
    # Set to eval mode and freeze so weights are not updated during inference.
    trained_model.eval()
    trained_model.freeze()
    test = load_dataset("lex_glue", "ecthr_a")['test']
    test = ut.preprocess_data(pd.DataFrame(test))
    # spaCy pipeline used for location/date entity extraction.
    NER = spacy.load("en_core_web_sm")
    return (trained_model, tokenizer, test, NER)


trained_model, tokenizer, test, NER = load_model()
st.header("Try out a text!") | |
with st.form('model_prediction'): | |
text = st.text_area("Input Text", test.iloc[0]['text'][20]) | |
text = text[:2000] | |
n1, n2, n3 = st.columns((0.2,0.4,0.4)) | |
ner_check = n1.checkbox("Extract Location and Date", value=True) | |
predict = n2.form_submit_button("Predict") | |
with st.spinner("Predicting..."): | |
if predict: | |
encoding = tokenizer.encode_plus(text, | |
add_special_tokens=True, | |
max_length=512, | |
return_token_type_ids=False, | |
padding="max_length", | |
return_attention_mask=True, | |
return_tensors='pt',) | |
# Predict on text | |
_, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"]) | |
prediction = list(prediction.flatten().numpy()) | |
final_predictions = [prediction.index(i) for i in prediction if i > ut.threshold] | |
if len(final_predictions)>0: | |
for i in final_predictions: | |
st.write('Violations: '+ ut.lex_classes[i] + ' : ' + str(round(prediction[i]*100, 2)) + ' %') | |
else: | |
st.write("Confidence less than 50%, Please try another text.") | |
if ner_check: | |
#Perform NER on a single text | |
n_text = NER(text) | |
loc = [] | |
date = [] | |
for word in n_text.ents: | |
print(word.text,word.label_) | |
if word.label_ == 'DATE': | |
date.append(word.text) | |
elif word.label_ == 'GPE': | |
loc.append(word.text) | |
loc = list(set(loc)) | |
date = list(set(date)) | |
loc = ["None found"] if len(loc)==0 else loc | |
date = ["None found"] if len(date)==0 else date | |
st.write("Location entities: " + ",".join(loc)) | |
st.write("Date entities: " + ",".join(date)) | |
#Display entities | |
st.write("All Entities-") | |
ent_html = displacy.render(n_text, style="ent", jupyter=False) | |
# Display the entity visualization in the browser: | |
st.markdown(ent_html, unsafe_allow_html=True) | |
st.header("Predict on test data") | |
with st.form('model_test_prediction'): | |
s1, s2, s3 = st.columns((0.2, 0.4, 0.4)) | |
top = s1.number_input("Count",1, len(test), value=10) | |
ner_check2 = s2.checkbox("Extract Location and Date", value=True) | |
predict2 = s2.form_submit_button("Predict") | |
with st.spinner("Predicting on test data"): | |
if predict2: | |
test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512) | |
# Predict on test data | |
predictions = [] | |
labels = [] | |
for item in tqdm(test_dataset): | |
_ , prediction = trained_model(item["input_ids"].unsqueeze(dim=0), | |
item["attention_mask"].unsqueeze(dim=0)) | |
predictions.append(prediction.flatten()) | |
labels.append(item["labels"].int()) | |
predictions = torch.stack(predictions) | |
labels = torch.stack(labels) | |
y_pred = predictions.numpy() | |
y_true = labels.numpy() | |
#Filter predictions | |
upper, lower = 1, 0 | |
y_pred = np.where(y_pred > ut.threshold, upper, lower) | |
# d1, d2 = st.columns((0.6, 0.4)) | |
#Accuracy | |
acc = round(float(accuracy(predictions, labels, threshold=ut.threshold))*100, 2) | |
out = test_dataset.data | |
out['predictions'] = [[list(i).index(j) for j in i if j==1] for i in y_pred] | |
out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x]) | |
out['predictions'] = out['predictions'].apply(lambda x: [ut.lex_classes[i] for i in x]) | |
if ner_check2: | |
#Perform NER on Test Dataset | |
out['nlp_text'] = out.text.apply(lambda x: NER(" ".join(x))) | |
#Extract Entities | |
out['location'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='GPE'])) | |
out['date'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='DATE'])) | |
st.dataframe(out.drop('nlp_text', axis=1)) | |
else: | |
st.dataframe(out) | |
s3.metric(label ='Accuracy',value = acc, delta = '', delta_color = 'inverse') | |
st.header("Comparison - Model Performance") | |
st.write("""2 transformer models were finetuned and compared their performance on the test dataset. \\ | |
- Bert uncased model (on Original & preprocessed text) \\ | |
- Legal-Bert model (on Original & preprocessed text)\\ | |
(Preprocessing steps were removal of numbers, symbols, stopwords followed by lemmatisation on tokens.)\\ | |
The best performing model is Legal-BERT on original data. Please see the comparison below.""") | |
met = pd.read_csv("model_comparison.csv") | |
a1, a2 = st.columns((0.5,0.5)) | |
a1.subheader("Evaluation Metrics") | |
a1.dataframe(met[12:].reset_index(drop=True)) | |
a2.subheader("Area under ROC curve") | |
a2.dataframe(met[:11]) |