Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from predict import run_prediction
|
3 |
+
from io import StringIO
|
4 |
+
import json
|
5 |
+
import spacy
|
6 |
+
from spacy import displacy
|
7 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
|
8 |
+
import torch
|
9 |
+
import nltk
|
10 |
+
from nltk.tokenize import sent_tokenize
|
11 |
+
from fin_readability_sustainability import BERTClass, do_predict
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
# ---------------------------------------------------------------------------
# Shared resources. Everything below runs at script level.
# ---------------------------------------------------------------------------

# Tokenizer data needed by nltk's sent_tokenize (used in get_sustainability).
nltk.download('punkt')
# spaCy English pipeline used for the entity view at the bottom of the page.
nlp = spacy.load("en_core_web_sm")

# Keep this ahead of every other Streamlit call in the script.
st.set_page_config(layout="wide")
# NOTE: a bare `st.cache(show_spinner=False, persist=True)` statement used to
# sit here. Its return value was discarded, so it cached nothing; it has been
# removed rather than left in place looking like it did something.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#SUSTAIN STARTS
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
# The checkpoint is a dict; the weights are stored under 'model_state_dict'.
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])
# Inference only: switch dropout / normalisation layers to evaluation mode.
model_sustain.eval()
|
26 |
+
|
27 |
+
|
28 |
+
def get_sustainability(text, non_sustainable_min=4.384316, sustainable_max=1.423736):
    """Label every sentence of *text* according to its sustainability score.

    The text is split with ``sent_tokenize``, the sentences are handed to
    ``do_predict`` together with the module-level model and tokenizer, and
    element ``[1]`` of the prediction result is read as one score per sentence.

    Args:
        text: Raw text to analyse.
        non_sustainable_min: Scores greater than or equal to this value are
            labelled ``'non-sustainable'``.
        sustainable_max: Scores less than or equal to this value are labelled
            ``'sustainable'``.

    Returns:
        A list of ``(sentence, label)`` tuples in sentence order, where label
        is ``'non-sustainable'``, ``'sustainable'`` or ``'-'`` (in between).
        Empty or whitespace-only input yields an empty list.
    """
    # Nothing to score: skip tokenisation and the model call entirely.
    if not text or not text.strip():
        return []

    df = pd.DataFrame({'sentence': sent_tokenize(text)})
    if df.empty:
        return []

    actual_predictions_sustainability = do_predict(model_sustain, tokenizer_sus, df)

    highlight = []
    for sent, score in zip(df['sentence'].values, actual_predictions_sustainability[1]):
        if score >= non_sustainable_min:
            label = 'non-sustainable'
        elif score <= sustainable_max:
            label = 'sustainable'
        else:
            label = '-'
        highlight.append((sent, label))
    return highlight
|
40 |
+
|
41 |
+
#SUSTAIN ENDS
|
42 |
+
|
43 |
+
##Summarization
|
44 |
+
def summarize_text(text):
    """Summarise *text* with the ``knkarthick/MEETING_SUMMARY`` model.

    Args:
        text: The text to summarise.

    Returns:
        The ``'summary_text'`` field of the first pipeline result, or an empty
        string when *text* is empty or only whitespace.
    """
    # Avoid building the model just to summarise nothing.
    if not text or not text.strip():
        return ""

    # NOTE(review): the pipeline is rebuilt on every call, which is slow.
    # Consider loading it once through Streamlit's resource caching.
    summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
    resp = summarizer(text)
    return resp[0]['summary_text']
|
49 |
+
|
50 |
+
##Forward Looking Statement
|
51 |
+
#def fls(text):
|
52 |
+
# fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
|
53 |
+
# results = fls_model(split_in_sentences(text))
|
54 |
+
#return make_spans(text,results)
|
55 |
+
|
56 |
+
##Company Extraction
|
57 |
+
#ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
|
58 |
+
#def fin_ner(text):
|
59 |
+
#replaced_spans = ner(text)
|
60 |
+
# return replaced_spans
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def load_questions():
    """Read the full-length questions from ``questions.txt``.

    The file is resolved relative to the current working directory and decoded
    as UTF-8 so the result does not depend on the platform's default encoding.

    Returns:
        A list with one string per line of the file, in file order. Each entry
        keeps its trailing newline, exactly as ``readlines()`` returns it.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE(review): entries still end in "\n"; confirm the consumer of these
    # questions does not need them stripped.
    with open('questions.txt', encoding='utf-8') as f:
        return f.readlines()
|
70 |
+
|
71 |
+
|
72 |
+
def load_questions_short():
    """Read the short question labels from ``questionshort.txt``.

    The file is resolved relative to the current working directory and decoded
    as UTF-8 so the result does not depend on the platform's default encoding.

    Returns:
        A list with one string per line of the file, in file order. Each entry
        keeps its trailing newline, exactly as ``readlines()`` returns it.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open('questionshort.txt', encoding='utf-8') as f:
        return f.readlines()
|
77 |
+
|
78 |
+
|
79 |
+
# NOTE: a bare `st.cache(show_spinner=False, persist=True)` statement used to
# precede these loads. It was not decorating anything and its result was
# discarded, so it has been removed.

# The page looks up the full question by the position of the chosen short
# label, so both files must list their entries in the same order.
questions = load_questions()
questions_short = load_questions_short()
|
84 |
+
|
85 |
+
### DEFINE SIDEBAR
st.sidebar.title("Interactive Contract Analysis")

st.sidebar.header('CONTRACT UPLOAD')

# upload contract
user_upload = st.sidebar.file_uploader('Please upload your contract', type=['txt'],
                                       accept_multiple_files=False)

# Always defined, so the contract text area can render before anything is
# uploaded and after a rejected upload (previously a NameError).
contract_data = ""

# process upload
if user_upload is not None:
    print(user_upload.name, user_upload.type)
    extension = user_upload.name.split('.')[-1].lower()
    if extension == 'txt':
        print('text file uploaded')
        try:
            # 'utf-8-sig' also accepts plain UTF-8 and drops a leading BOM.
            contract_data = user_upload.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.warning('Uploaded file is not valid UTF-8 text, please try again')
    else:
        st.warning('Unknown uploaded file type, please try again')

# Options are strings; the consumer converts the selection with int().
results_drop = ['1', '2', '3']
number_results = st.sidebar.selectbox('Select number of results', results_drop)
|
111 |
+
|
112 |
+
### DEFINE MAIN PAGE
st.header("Legal Contract Review Demo")
paragraph = st.text_area(label="Contract", value=contract_data, height=300)

# The dropdown shows the short labels; the model is given the full question
# stored at the same position in `questions`.
questions_drop = questions_short
question_short = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions_drop)
idxq = questions_drop.index(question_short)
question = questions[idxq]

# Streamlit re-executes this script from the top on every widget interaction.
# A plain variable would be empty again by the time one of the follow-up
# buttons is clicked, so the extracted answer is kept in session state.
if 'raw_answer' not in st.session_state:
    st.session_state['raw_answer'] = ""

if st.button('Analyze'):
    # A new analysis invalidates whatever answer was stored before.
    st.session_state['raw_answer'] = ""
    if paragraph and question:
        print('getting predictions')
        with st.spinner(text='Analysis in progress...'):
            predictions = run_prediction([question], paragraph, 'marshmellow77/roberta-base-cuad',
                                         n_best_size=5)
        if predictions['0'] == "":
            answer = 'No answer found in document'
        else:
            # presumably nbest.json is written by run_prediction -- TODO confirm
            with open("nbest.json") as jf:
                data = json.load(jf)
            answer = ""
            # Slice rather than index: if fewer candidates exist than were
            # requested, show what is there instead of raising IndexError.
            for i, candidate in enumerate(data['0'][:int(number_results)]):
                # NOTE(review): as before, the LAST displayed candidate is the
                # one fed to the follow-up tools; confirm the top-ranked one
                # is not what was intended.
                st.session_state['raw_answer'] = candidate['text']
                answer += f"Answer {i+1}: {candidate['text']} -- \n"
                answer += f"Probability: {round(candidate['probability']*100,1)}%\n\n"
        st.success(answer)
    else:
        st.write("Unable to call model, please select question and contract")

# Answer from this run or from an earlier 'Analyze' click in the same session.
raw_answer = st.session_state['raw_answer']

if st.button('Check Sustainability'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(get_sustainability(raw_answer))

if st.button('Summarize'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(summarize_text(raw_answer))

if st.button('NER'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        doc = nlp(raw_answer)
        # displaCy returns HTML markup. It has to be rendered as HTML;
        # st.write would display the markup itself as text.
        ent_html = displacy.render(doc, style="ent", jupyter=False)
        # Newlines inside the markup interfere with Markdown rendering.
        ent_html = ent_html.replace("\n", " ")
        st.markdown(ent_html, unsafe_allow_html=True)
|