akdeniz27's picture
Update app.py
7cd7d9d
raw
history blame
2.75 kB
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import streamlit as st
import json
from predict import run_prediction
st.set_page_config(layout="wide")
model_list = ['akdeniz27/roberta-base-cuad',
'akdeniz27/roberta-large-cuad',
'akdeniz27/deberta-v2-xlarge-cuad']
st.sidebar.header("Select CUAD Model")
model_checkpoint = st.sidebar.radio("", model_list)
if model_checkpoint == "akdeniz27/deberta-v2-xlarge-cuad": import sentencepiece
st.sidebar.write("Project: https://www.atticusprojectai.org/cuad")
st.sidebar.write("Git Hub: https://github.com/TheAtticusProject/cuad")
st.sidebar.write("CUAD Dataset: https://huggingface.co/datasets/cuad")
st.sidebar.write("License: CC BY 4.0 https://creativecommons.org/licenses/by/4.0/")
@st.cache(allow_output_mutation=True)
def load_model():
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint , use_fast=False)
return model, tokenizer
@st.cache(allow_output_mutation=True)
def load_questions():
with open('test.json') as json_file:
data = json.load(json_file)
questions = []
for i, q in enumerate(data['data'][0]['paragraphs'][0]['qas']):
question = data['data'][0]['paragraphs'][0]['qas'][i]['question']
questions.append(question)
return questions
@st.cache(allow_output_mutation=True)
def load_contracts():
with open('test.json') as json_file:
data = json.load(json_file)
contracts = []
for i, q in enumerate(data['data']):
contract = ' '.join(data['data'][i]['paragraphs'][0]['context'].split())
contracts.append(contract)
return contracts
model, tokenizer = load_model()
questions = load_questions()
contracts = load_contracts()
contract = contracts[0]
st.header("Contract Understanding Atticus Dataset (CUAD) Demo")
st.write("Based on https://github.com/marshmellow77/cuad-demo")
selected_question = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions)
question_set = [questions[0], selected_question]
contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))
if contract_type == "Sample Contract":
sample_contract_num = st.slider("Select Sample Contract #")
contract = contracts[sample_contract_num]
with st.expander(f"Sample Contract #{sample_contract_num}"):
st.write(contract)
else:
contract = st.text_area("Input New Contract", "", height=256)
Run_Button = st.button("Run", key=None)
if Run_Button == True and not len(contract)==0 and not len(question_set)==0:
predictions = run_prediction(question_set, contract, 'akdeniz27/roberta-base-cuad')
for i, p in enumerate(predictions):
if i != 0: st.write(f"Question: {question_set[int(p)]}\n\nAnswer: {predictions[p]}\n\n")