File size: 7,298 Bytes
88afd90
 
 
3981553
 
 
88afd90
 
 
 
 
 
 
 
 
 
 
 
11dfa1d
88afd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d334f3e
74b1387
88afd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d334f3e
 
88afd90
 
 
d334f3e
88afd90
d334f3e
 
 
 
 
 
 
11dfa1d
 
 
 
 
 
88afd90
 
 
74b1387
88afd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
@author:jishnuprakash
"""
import nltk
nltk.download('stopwords')

import os
import torch
import spacy
import utils as ut
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from spacy import displacy
from nltk import word_tokenize
from nltk.probability import FreqDist
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from sklearn.metrics import classification_report

st.set_page_config(page_title='NLP Challenge- JP',  layout='wide', page_icon=':computer:')
st.set_option('deprecation.showPyplotGlobalUse', False)

# Page header: centred title, subtitle and an author/links line, rendered as raw HTML.
st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
# The style attribute previously ended in a stray extra quote (`center''`), producing malformed HTML.
st.markdown("<div style='text-align: center'> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
st.text('')
# Collapsible description of what the app does.
expander = st.expander("View Description")
# The escaped backslash at the end of a line yields a markdown hard line break.
expander.write("""This is a minimal user interface implementation to view and interact with 
                  results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
                  Try inputting a text below and see the model predictions. You can also extract the location
                  and date entities from the text using the checkbox.\\
                  Below, you can do the same on test data. """)


# Heavy resources (classifier, tokenizer, test split, spaCy pipeline) are loaded via a cached loader.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load everything the app needs for inference.

    Memoised with ``st.cache`` so Streamlit reruns reuse the loaded objects
    instead of reloading them.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)`` where
            trained_model -- the ``ut.LexGlueTagger`` restored from
                ``ut.check_filename + '.ckpt'``, switched to eval mode and frozen;
            tokenizer -- the Hugging Face tokenizer for ``ut.bert_model``;
            test -- the LexGLUE ``ecthr_a`` test split as a DataFrame, passed
                through ``ut.preprocess_data``;
            NER -- the spaCy ``en_core_web_sm`` pipeline.
    """
    checkpoint_path = ut.check_filename + '.ckpt'
    model = ut.LexGlueTagger.load_from_checkpoint(checkpoint_path, num_classes=ut.num_classes)
    bert_tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)

    # Inference only: eval mode, and freeze so no weights can be updated.
    model.eval()
    model.freeze()

    raw_split = load_dataset("lex_glue", "ecthr_a")['test']
    test_frame = ut.preprocess_data(pd.DataFrame(raw_split))

    ner_pipeline = spacy.load("en_core_web_sm")
    return model, bert_tokenizer, test_frame, ner_pipeline


trained_model, tokenizer, test, NER = load_model()

def _indices_above(probabilities, threshold):
    """Return the positions of every score strictly greater than *threshold*, in order.

    Uses ``enumerate`` rather than ``list.index`` so that labels with identical
    scores each keep their own position (``index`` would always return the first).
    """
    return [idx for idx, score in enumerate(probabilities) if score > threshold]


def _join_or_none(items):
    """Comma-join *items*, or return "None found" when *items* is empty.

    The previous code joined the *string* "None found" itself, which printed
    its characters separated by commas.
    """
    return ",".join(items) if items else "None found"


st.header("Try out a text!")
with st.form('model_prediction'):
    text = st.text_area("Input Text", test.iloc[0]['text'][20])
    n1, n2, n3 = st.columns((0.2, 0.4, 0.4))
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")
    # Only show the spinner (and do any work) once the form has been submitted.
    if predict:
        with st.spinner("Predicting..."):
            # truncation=True makes the 512-token cut explicit instead of relying on a fallback.
            encoding = tokenizer.encode_plus(text,
                                             add_special_tokens=True,
                                             max_length=512,
                                             truncation=True,
                                             return_token_type_ids=False,
                                             padding="max_length",
                                             return_attention_mask=True,
                                             return_tensors='pt',)
            # Predict on text; no autograd graph is needed for inference.
            with torch.no_grad():
                _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
            prediction = prediction.flatten().tolist()

            final_predictions = _indices_above(prediction, ut.threshold)
            if final_predictions:
                for i in final_predictions:
                    st.write('Violations: ' + ut.lex_classes[i] + ' : ' + str(round(prediction[i] * 100, 2)) + ' %')
            else:
                # Report the cut-off actually used instead of a hard-coded "50%".
                st.write(f"Confidence less than {ut.threshold:.0%}, Please try another text.")

            if ner_check:
                # Run spaCy NER over the whole input text.
                n_text = NER(text)
                # dict.fromkeys de-duplicates while keeping first-seen order (sets were unordered).
                loc = list(dict.fromkeys(ent.text for ent in n_text.ents if ent.label_ == 'GPE'))
                date = list(dict.fromkeys(ent.text for ent in n_text.ents if ent.label_ == 'DATE'))
                st.write("Location entities: " + _join_or_none(loc))
                st.write("Date entities: " + _join_or_none(date))

                # Display all entities, highlighted in the text, via displaCy's HTML renderer.
                st.write("All Entities-")
                ent_html = displacy.render(n_text, style="ent", jupyter=False)
                st.markdown(ent_html, unsafe_allow_html=True)

st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.2, 0.4, 0.4))
    top = s1.number_input("Count", 1, len(test), value=10)
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")
    # Only show the spinner (and do any work) once the form has been submitted.
    if predict2:
        with st.spinner("Predicting on test data"):
            test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)

            # Score the first `top` test documents one at a time (batch of size 1).
            predictions = []
            labels = []
            with torch.no_grad():
                for item in tqdm(test_dataset):
                    _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                                  item["attention_mask"].unsqueeze(dim=0))
                    predictions.append(prediction.flatten())
                    labels.append(item["labels"].int())

            predictions = torch.stack(predictions)
            labels = torch.stack(labels)

            # Binarise scores at the shared threshold: 1 = label predicted, 0 = not.
            y_pred = (predictions.numpy() > ut.threshold).astype(int)

            # Accuracy
            acc = round(float(accuracy(predictions, labels, threshold=ut.threshold)) * 100, 2)

            # Copy before adding/replacing columns so the frame built from the
            # cached `test` rows is not modified in place.
            out = test_dataset.data.copy()
            # flatnonzero yields EVERY positive column; the old `list(i).index(j)`
            # always returned the first one, so multi-label rows repeated one label.
            out['predictions'] = [np.flatnonzero(row).tolist() for row in y_pred]
            out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])
            out['predictions'] = out['predictions'].apply(lambda x: [ut.lex_classes[i] for i in x])

            if ner_check2:
                # Run spaCy NER on each document's text pieces joined with spaces.
                docs = [NER(" ".join(pieces)) for pieces in out.text]
                # Sorted, de-duplicated lists: deterministic order and friendlier to
                # st.dataframe's serialisation than Python sets.
                out['location'] = [sorted({ent.text for ent in doc.ents if ent.label_ == 'GPE'}) for doc in docs]
                out['date'] = [sorted({ent.text for ent in doc.ents if ent.label_ == 'DATE'}) for doc in docs]

            st.dataframe(out)
            s3.metric(label='Accuracy', value=acc, delta='', delta_color='inverse')