Shredder commited on
Commit
a560ed2
1 Parent(s): 20db3f3

Create new file

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from predict import run_prediction
3
+ from io import StringIO
4
+ import json
5
+ import spacy
6
+ from spacy import displacy
7
+ from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
8
+ import torch
9
+ import nltk
10
+ from nltk.tokenize import sent_tokenize
11
+ from fin_readability_sustainability import BERTClass, do_predict
12
+ import pandas as pd
13
+
14
+ nltk.download('punkt')
15
+ nlp = spacy.load("en_core_web_sm")
16
+
17
+ st.set_page_config(layout="wide")
18
+ st.cache(show_spinner=False, persist=True)
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ #SUSTAIN STARTS
22
+ tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
23
+ model_sustain = BERTClass(2, "sustanability")
24
+ model_sustain.to(device)
25
+ model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])
26
+
27
+
28
+ def get_sustainability(text):
29
+ df = pd.DataFrame({'sentence':sent_tokenize(text)})
30
+ actual_predictions_sustainability = do_predict(model_sustain, tokenizer_sus, df)
31
+ highlight = []
32
+ for sent, prob in zip(df['sentence'].values, actual_predictions_sustainability[1]):
33
+ if prob>=4.384316:
34
+ highlight.append((sent, 'non-sustainable'))
35
+ elif prob<=1.423736:
36
+ highlight.append((sent, 'sustainable'))
37
+ else:
38
+ highlight.append((sent, '-'))
39
+ return highlight
40
+
41
+ #SUSTAIN ENDS
42
+
43
+ ##Summarization
44
+ def summarize_text(text):
45
+ summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
46
+ resp = summarizer(text)
47
+ stext = resp[0]['summary_text']
48
+ return stext
49
+
50
+ ##Forward Looking Statement
51
+ #def fls(text):
52
+ # fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
53
+ # results = fls_model(split_in_sentences(text))
54
+ #return make_spans(text,results)
55
+
56
+ ##Company Extraction
57
+ #ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
58
+ #def fin_ner(text):
59
+ #replaced_spans = ner(text)
60
+ # return replaced_spans
61
+
62
+
63
+
64
+
65
+ def load_questions():
66
+ questions = []
67
+ with open('questions.txt') as f:
68
+ questions = f.readlines()
69
+ return questions
70
+
71
+
72
+ def load_questions_short():
73
+ questions_short = []
74
+ with open('questionshort.txt') as f:
75
+ questions_short = f.readlines()
76
+ return questions_short
77
+
78
+
79
+ st.cache(show_spinner=False, persist=True)
80
+
81
+
82
+ questions = load_questions()
83
+ questions_short = load_questions_short()
84
+
85
+ ### DEFINE SIDEBAR
86
+ st.sidebar.title("Interactive Contract Analysis")
87
+
88
+ st.sidebar.header('CONTRACT UPLOAD')
89
+
90
+ # upload contract
91
+ user_upload = st.sidebar.file_uploader('Please upload your contract', type=['txt'],
92
+ accept_multiple_files=False)
93
+
94
+
95
+ # process upload
96
+ if user_upload is not None:
97
+ print(user_upload.name, user_upload.type)
98
+ extension = user_upload.name.split('.')[-1].lower()
99
+ if extension == 'txt':
100
+ print('text file uploaded')
101
+ # To convert to a string based IO:
102
+ stringio = StringIO(user_upload.getvalue().decode("utf-8"))
103
+
104
+ # To read file as string:
105
+ contract_data = stringio.read()
106
+ else:
107
+ st.warning('Unknown uploaded file type, please try again')
108
+
109
+ results_drop = ['1', '2', '3']
110
+ number_results = st.sidebar.selectbox('Select number of results', results_drop)
111
+
112
+ ### DEFINE MAIN PAGE
113
+ st.header("Legal Contract Review Demo")
114
+ paragraph = st.text_area(label="Contract", value=contract_data, height=300)
115
+
116
+ questions_drop = questions_short
117
+ question_short = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions_drop)
118
+ idxq = questions_drop.index(question_short)
119
+ question = questions[idxq]
120
+
121
+
122
+ raw_answer=""
123
+ if st.button('Analyze'):
124
+ if (not len(paragraph)==0) and not (len(question)==0):
125
+ print('getting predictions')
126
+ with st.spinner(text='Analysis in progress...'):
127
+ predictions = run_prediction([question], paragraph, 'marshmellow77/roberta-base-cuad',
128
+ n_best_size=5)
129
+ answer = ""
130
+ if predictions['0'] == "":
131
+ answer = 'No answer found in document'
132
+ else:
133
+ # if number_results == '1':
134
+ # answer = f"Answer: {predictions['0']}"
135
+ # # st.text_area(label="Answer", value=f"{answer}")
136
+ # else:
137
+ answer = ""
138
+ with open("nbest.json") as jf:
139
+ data = json.load(jf)
140
+ for i in range(int(number_results)):
141
+ raw_answer=data['0'][i]['text']
142
+ answer += f"Answer {i+1}: {data['0'][i]['text']} -- \n"
143
+ answer += f"Probability: {round(data['0'][i]['probability']*100,1)}%\n\n"
144
+ st.success(answer)
145
+
146
+ else:
147
+ st.write("Unable to call model, please select question and contract")
148
+
149
+ if st.button('Check Sustainability'):
150
+ if(raw_answer==""):
151
+ st.write("Unable to call model, please select question and contract")
152
+ else:
153
+ st.write(get_sustainability(raw_answer))
154
+ if st.button('Summarize'):
155
+ if(raw_answer==""):
156
+ st.write("Unable to call model, please select question and contract")
157
+ else:
158
+ st.write(summarize_text(raw_answer))
159
+
160
+ if st.button('NER'):
161
+ if(raw_answer==""):
162
+ st.write("Unable to call model, please select question and contract")
163
+ else:
164
+ doc = nlp(raw_answer)
165
+ st.write(displacy.render(doc, style="ent"))