muhtasham commited on
Commit
47dd29e
1 Parent(s): 13cf51e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -55
app.py CHANGED
@@ -21,58 +21,6 @@ st.sidebar.write("Git Hub: https://github.com/TheAtticusProject/cuad")
21
  st.sidebar.write("CUAD Dataset: https://huggingface.co/datasets/cuad")
22
 
23
  @st.cache(allow_output_mutation=True)
24
- def load_model():
25
- model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
26
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint , use_fast=False)
27
- return model, tokenizer
28
-
29
- @st.cache(allow_output_mutation=True)
30
- def load_questions():
31
- with open('test.json') as json_file:
32
- data = json.load(json_file)
33
- questions = []
34
- for i, q in enumerate(data['data'][0]['paragraphs'][0]['qas']):
35
- question = data['data'][0]['paragraphs'][0]['qas'][i]['question']
36
- questions.append(question)
37
- return questions
38
-
39
- @st.cache(allow_output_mutation=True)
40
- def load_contracts():
41
- with open('test.json') as json_file:
42
- data = json.load(json_file)
43
- contracts = []
44
- for i, q in enumerate(data['data']):
45
- contract = ' '.join(data['data'][i]['paragraphs'][0]['context'].split())
46
- contracts.append(contract)
47
- return contracts
48
-
49
- model, tokenizer = load_model()
50
- questions = load_questions()
51
- contracts = load_contracts()
52
- contract = contracts[0]
53
-
54
- st.header("Contract Understanding Atticus Dataset (CUAD) Demo")
55
- st.write("Based on https://github.com/marshmellow77/cuad-demo")
56
-
57
- selected_question = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions)
58
- question_set = [questions[0], selected_question]
59
- contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))
60
-
61
- if contract_type == "Sample Contract":
62
- sample_contract_num = st.slider("Select Sample Contract #")
63
- contract = contracts[sample_contract_num]
64
- with st.expander(f"Sample Contract #{sample_contract_num}"):
65
- st.write(contract)
66
- else:
67
- contract = st.text_area("Input New Contract", "", height=256)
68
- Run_Button = st.button("Run", key=None)
69
- if Run_Button == True and not len(contract)==0 and not len(question_set)==0:
70
- predictions = run_prediction(question_set, contract, 'akdeniz27/roberta-base-cuad')
71
-
72
- for i, p in enumerate(predictions):
73
- if i != 0: st.write(f"Question: {question_set[int(p)]}\n\nAnswer: {predictions[p]}\n\n")
74
-
75
-
76
  def run_prediction(question_texts, context_text, model_path):
77
  max_seq_length = 512
78
  doc_stride = 256
@@ -81,8 +29,7 @@ def run_prediction(question_texts, context_text, model_path):
81
  max_answer_length = 512
82
  do_lower_case = False
83
  null_score_diff_threshold = 0.0
84
-
85
- def to_list(tensor):
86
  return tensor.detach().cpu().tolist()
87
  config_class, model_class, tokenizer_class = (
88
  AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer)
@@ -151,4 +98,56 @@ def run_prediction(question_texts, context_text, model_path):
151
  null_score_diff_threshold=null_score_diff_threshold,
152
  tokenizer=tokenizer
153
  )
154
- return final_predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  st.sidebar.write("CUAD Dataset: https://huggingface.co/datasets/cuad")
22
 
23
  @st.cache(allow_output_mutation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def run_prediction(question_texts, context_text, model_path):
25
  max_seq_length = 512
26
  doc_stride = 256
 
29
  max_answer_length = 512
30
  do_lower_case = False
31
  null_score_diff_threshold = 0.0
32
+ def to_list(tensor):
 
33
  return tensor.detach().cpu().tolist()
34
  config_class, model_class, tokenizer_class = (
35
  AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer)
 
98
  null_score_diff_threshold=null_score_diff_threshold,
99
  tokenizer=tokenizer
100
  )
101
+ return final_predictions
102
+
103
+ @st.cache(allow_output_mutation=True)
104
+ def load_model():
105
+ model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
106
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint , use_fast=False)
107
+ return model, tokenizer
108
+
109
+ @st.cache(allow_output_mutation=True)
110
+ def load_questions():
111
+ with open('test.json') as json_file:
112
+ data = json.load(json_file)
113
+ questions = []
114
+ for i, q in enumerate(data['data'][0]['paragraphs'][0]['qas']):
115
+ question = data['data'][0]['paragraphs'][0]['qas'][i]['question']
116
+ questions.append(question)
117
+ return questions
118
+
119
+ @st.cache(allow_output_mutation=True)
120
+ def load_contracts():
121
+ with open('test.json') as json_file:
122
+ data = json.load(json_file)
123
+ contracts = []
124
+ for i, q in enumerate(data['data']):
125
+ contract = ' '.join(data['data'][i]['paragraphs'][0]['context'].split())
126
+ contracts.append(contract)
127
+ return contracts
128
+
129
+ model, tokenizer = load_model()
130
+ questions = load_questions()
131
+ contracts = load_contracts()
132
+ contract = contracts[0]
133
+
134
+ st.header("Contract Understanding Atticus Dataset (CUAD) Demo")
135
+ st.write("Based on https://github.com/marshmellow77/cuad-demo")
136
+
137
+ selected_question = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions)
138
+ question_set = [questions[0], selected_question]
139
+ contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))
140
+
141
+ if contract_type == "Sample Contract":
142
+ sample_contract_num = st.slider("Select Sample Contract #")
143
+ contract = contracts[sample_contract_num]
144
+ with st.expander(f"Sample Contract #{sample_contract_num}"):
145
+ st.write(contract)
146
+ else:
147
+ contract = st.text_area("Input New Contract", "", height=256)
148
+ Run_Button = st.button("Run", key=None)
149
+ if Run_Button == True and not len(contract)==0 and not len(question_set)==0:
150
+ predictions = run_prediction(question_set, contract, 'akdeniz27/roberta-base-cuad')
151
+
152
+ for i, p in enumerate(predictions):
153
+ if i != 0: st.write(f"Question: {question_set[int(p)]}\n\nAnswer: {predictions[p]}\n\n")