muhtasham committed on
Commit
13cf51e
1 Parent(s): 6ccc4b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoConfig,
3
+ AutoModelForQuestionAnswering,
4
+ AutoTokenizer,
5
+ squad_convert_examples_to_features
6
+ )
7
+
8
+ from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
9
+ from transformers.data.metrics.squad_metrics import compute_predictions_logits
10
+ import streamlit as st
11
+ import gradio as gr
12
+ import json
13
+ import torch
14
+ import time
15
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
16
+
17
# Hugging Face Hub id of the question-answering checkpoint used by this app.
model_checkpoint = "akdeniz27/roberta-base-cuad"

# Sidebar: model and dataset provenance. The model label is built from
# `model_checkpoint` so the displayed name cannot drift from the id that is
# actually loaded.
st.sidebar.write(f"Model: {model_checkpoint}")
st.sidebar.write("Project: https://www.atticusprojectai.org/cuad")
st.sidebar.write("Git Hub: https://github.com/TheAtticusProject/cuad")
st.sidebar.write("CUAD Dataset: https://huggingface.co/datasets/cuad")
22
+
23
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the question-answering model and tokenizer for ``model_checkpoint``.

    The tokenizer is requested with ``use_fast=False`` (the slow, pure-Python
    implementation).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tuple elements are evaluated left to right: model first, then tokenizer.
    return (
        AutoModelForQuestionAnswering.from_pretrained(model_checkpoint),
        AutoTokenizer.from_pretrained(model_checkpoint, use_fast=False),
    )
28
+
29
@st.cache(allow_output_mutation=True)
def load_questions():
    """Return the question strings stored for the first contract in ``test.json``.

    Only ``data[0]["paragraphs"][0]["qas"]`` is read; questions attached to
    other contracts in the file are not consulted.

    Returns:
        A list with the ``"question"`` value of every entry, in file order.

    Raises:
        OSError: if ``test.json`` cannot be opened.
        KeyError / IndexError: if the file does not have the expected layout.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # text encoding, which differs between operating systems.
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return [qa['question'] for qa in data['data'][0]['paragraphs'][0]['qas']]
38
+
39
@st.cache(allow_output_mutation=True)
def load_contracts():
    """Return the text of every contract in ``test.json``.

    For each entry of ``data`` the ``context`` of its first paragraph is
    taken and all runs of whitespace (including newlines) are collapsed to a
    single space.

    Returns:
        A list of contract strings, in file order.

    Raises:
        OSError: if ``test.json`` cannot be opened.
        KeyError / IndexError: if the file does not have the expected layout.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # text encoding, which differs between operating systems.
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return [
        ' '.join(entry['paragraphs'][0]['context'].split())
        for entry in data['data']
    ]
48
+
49
def run_prediction(question_texts, context_text, model_path):
    """Answer each question against one context with an extractive QA model.

    This function is defined *before* the UI code below because the script is
    executed top to bottom; calling it ahead of its ``def`` raises
    ``NameError``.

    Args:
        question_texts: Iterable of question strings. Each one becomes a
            ``SquadExample`` whose ``qas_id`` is its position as a string
            (``"0"``, ``"1"``, ...).
        context_text: The document text to search for answers.
        model_path: Model id or path passed to ``from_pretrained`` for the
            config, tokenizer and model.

    Returns:
        The value returned by ``compute_predictions_logits``. The caller in
        this file indexes it by the string ``qas_id`` assigned above.
    """
    max_seq_length = 512
    doc_stride = 256
    n_best_size = 1
    max_query_length = 64
    max_answer_length = 512
    do_lower_case = False
    null_score_diff_threshold = 0.0

    def to_list(tensor):
        return tensor.detach().cpu().tolist()

    # NOTE(review): config, tokenizer and model are loaded again on every
    # call, independently of the cached `load_model()` result. The tokenizer
    # here is created with `do_lower_case=True` while `load_model()` passes no
    # such argument, so the two are not known to be interchangeable -- confirm
    # before reusing the cached objects.
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, do_lower_case=True, use_fast=False)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_path, config=config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: switch to eval mode once instead of once per batch.
    model.eval()

    examples = [
        SquadExample(
            qas_id=str(index),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            answers=None,
        )
        for index, question_text in enumerate(question_texts)
    ]

    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)

    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }
            example_indices = batch[3]
            outputs = model(**inputs)

        output_tensors = outputs.to_tuple()
        for row, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            # Exactly two tensors are expected here (start and end logits);
            # the unpacking fails loudly if the model returns anything else.
            start_logits, end_logits = [
                to_list(tensor[row]) for tensor in output_tensors
            ]
            all_results.append(
                SquadResult(unique_id, start_logits, end_logits))

    final_predictions = compute_predictions_logits(
        all_examples=examples,
        all_features=features,
        all_results=all_results,
        n_best_size=n_best_size,
        max_answer_length=max_answer_length,
        do_lower_case=do_lower_case,
        output_prediction_file=None,
        output_nbest_file=None,
        output_null_log_odds_file=None,
        verbose_logging=False,
        version_2_with_negative=True,
        null_score_diff_threshold=null_score_diff_threshold,
        tokenizer=tokenizer,
    )
    return final_predictions


# ---------------------------------------------------------------------------
# Streamlit page
# ---------------------------------------------------------------------------

# NOTE(review): nothing visible in this file reads `model` / `tokenizer`
# afterwards; `run_prediction` loads its own copies. Kept so the load still
# happens at start-up as before.
model, tokenizer = load_model()
questions = load_questions()
contracts = load_contracts()

st.header("Contract Understanding Atticus Dataset (CUAD) Demo")
st.write("Based on https://github.com/marshmellow77/cuad-demo")

selected_question = st.selectbox(
    'Choose one of the 41 queries from the CUAD dataset:', questions)
# NOTE(review): the first question is always submitted alongside the selected
# one and its answer is skipped when printing results. The reason is not
# visible in this file; behaviour is kept as is.
question_set = [questions[0], selected_question]
contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))

if contract_type == "Sample Contract":
    # Bound the slider to the contracts that were actually loaded so the
    # chosen number is always a valid index into `contracts`.
    last_contract_index = len(contracts) - 1
    if last_contract_index > 0:
        sample_contract_num = st.slider(
            "Select Sample Contract #", 0, last_contract_index)
    else:
        # A slider needs a non-empty range; with a single contract there is
        # nothing to choose.
        sample_contract_num = 0
    contract = contracts[sample_contract_num]
    with st.expander(f"Sample Contract #{sample_contract_num}"):
        st.write(contract)
else:
    contract = st.text_area("Input New Contract", "", height=256)

Run_Button = st.button("Run", key=None)
if Run_Button and contract and question_set:
    predictions = run_prediction(question_set, contract, model_checkpoint)

    for position, (qas_id, answer) in enumerate(predictions.items()):
        if position == 0:
            # Skip the always-included first question (see note above).
            continue
        st.write(
            f"Question: {question_set[int(qas_id)]}\n\nAnswer: {answer}\n\n")