jishnuprakash commited on
Commit
88afd90
1 Parent(s): 44340f9
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @author:jishnuprakash
3
+ """
4
+ import os
5
+ import torch
6
+ import spacy
7
+ import utils as ut
8
+ import streamlit as st
9
+ import pandas as pd
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ import pandas as pd
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from nltk import word_tokenize
17
+ from nltk.probability import FreqDist
18
+ from matplotlib import pyplot as plt
19
+ from nltk.corpus import stopwords
20
+ from tqdm.auto import tqdm
21
+ from datasets import load_dataset
22
+ from transformers import AutoTokenizer, AutoModel
23
+ from pytorch_lightning.metrics.functional import accuracy, f1, auroc
24
+ from sklearn.metrics import classification_report
25
+
26
+
27
+ st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
28
+ st.set_option('deprecation.showPyplotGlobalUse', False)
29
+
30
+ #this is the header
31
+ st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
32
+ st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
33
+ st.markdown("<div style='text-align: center''> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
34
+ st.text('')
35
+ expander = st.expander("View Description")
36
+ expander.write("""This is minimal user interface implemetation to view and interact with
37
+ results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
38
+ Try inputing a text below and see the model predictions. You can also extract the location
39
+ and Date entities from the text using the checkbox.\\
40
+ Below, you can do the same on test data. """)
41
+
42
+
43
+ #Load trained model
44
+ @st.cache(allow_output_mutation=True)
45
+ def load_model():
46
+ trained_model = ut.LexGlueTagger.load_from_checkpoint(ut.check_filename+'.ckpt', num_classes = ut.num_classes)
47
+ #Initialise BERT tokenizer
48
+ tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)
49
+ #Set to Eval and freeze to avoid weight update
50
+ trained_model.eval()
51
+ trained_model.freeze()
52
+ test = load_dataset("lex_glue", "ecthr_a")['test']
53
+ test = ut.preprocess_data(pd.DataFrame(test))
54
+ #Load Model from Spacy
55
+ NER = spacy.load("en_core_web_sm")
56
+ return (trained_model, tokenizer, test, NER)
57
+
58
+ trained_model, tokenizer, test, NER = load_model()
59
+
60
+ st.header("Try out a text!")
61
+ with st.form('model_prediction'):
62
+ text = st.text_area("Input Text", " ".join(test.iloc[0]['text'])[:1525])
63
+ n1, n2, n3 = st.columns((0.13,0.3,0.4))
64
+ ner_check = n1.checkbox("Extract Location and Date", value=True)
65
+ predict = n2.form_submit_button("Predict")
66
+ with st.spinner("Predicting..."):
67
+ if predict:
68
+ encoding = tokenizer.encode_plus(text,
69
+ add_special_tokens=True,
70
+ max_length=512,
71
+ return_token_type_ids=False,
72
+ padding="max_length",
73
+ return_attention_mask=True,
74
+ return_tensors='pt',)
75
+ # Predict on text
76
+ _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
77
+ prediction = list(prediction.flatten().numpy())
78
+
79
+ final_predictions = [prediction.index(i) for i in prediction if i > ut.threshold]
80
+ if len(final_predictions)>0:
81
+ for i in final_predictions:
82
+ st.write('Violations: '+ ut.lex_classes[i] + ' : ' + str(round(prediction[i]*100, 2)) + ' %')
83
+ else:
84
+ st.write("Confidence less than 50%, Please try another text.")
85
+
86
+ if ner_check:
87
+ #Perform NER on a single text
88
+ n_text = NER(text)
89
+ loc = ''
90
+ date = ''
91
+ for word in n_text.ents:
92
+ print(word.text,word.label_)
93
+ if word.label_ == 'DATE':
94
+ date += word.text + ', '
95
+ elif word.label_ == 'GPE':
96
+ loc += word.text + ', '
97
+ loc = "None found" if len(loc)<1 else loc
98
+ date = "None found" if len(date)<1 else date
99
+ st.write("Location entities: " + loc)
100
+ st.write("Date entities: " + date)
101
+
102
+ st.header("Predict on test data")
103
+ with st.form('model_test_prediction'):
104
+ s1, s2, s3 = st.columns((0.1, 0.3, 0.6))
105
+ top = s1.number_input("Count",1, len(test), value=10)
106
+ ner_check2 = s2.checkbox("Extract Location and Date", value=True)
107
+ predict2 = s2.form_submit_button("Predict")
108
+ with st.spinner("Predicting on test data"):
109
+ if predict2:
110
+ test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)
111
+
112
+ # Predict on test data
113
+ predictions = []
114
+ labels = []
115
+
116
+ for item in tqdm(test_dataset):
117
+ _ , prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
118
+ item["attention_mask"].unsqueeze(dim=0))
119
+ predictions.append(prediction.flatten())
120
+ labels.append(item["labels"].int())
121
+
122
+ predictions = torch.stack(predictions)
123
+ labels = torch.stack(labels)
124
+
125
+ y_pred = predictions.numpy()
126
+ y_true = labels.numpy()
127
+
128
+ #Filter predictions
129
+ upper, lower = 1, 0
130
+ y_pred = np.where(y_pred > ut.threshold, upper, lower)
131
+ # d1, d2 = st.columns((0.6, 0.4))
132
+
133
+ #Accuracy
134
+ acc = round(float(accuracy(predictions, labels, threshold=ut.threshold))*100, 2)
135
+
136
+ out = test_dataset.data
137
+ out['predictions'] = [[list(i).index(j) for j in i if j==1] for i in y_pred]
138
+ out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])
139
+ out['predictions'] = out['predictions'].apply(lambda x: [ut.lex_classes[i] for i in x])
140
+
141
+ if ner_check2:
142
+ #Perform NER on Test Dataset
143
+ out['nlp_text'] = out.text.apply(lambda x: NER(" ".join(x)))
144
+
145
+ #Extract Entities
146
+ out['location'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='GPE']))
147
+ out['date'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='DATE']))
148
+
149
+ st.dataframe(out.drop('nlp_text', axis=1))
150
+ else:
151
+ st.dataframe(out)
152
+ s3.metric(label ='Accuracy',value = acc, delta = '', delta_color = 'inverse')