Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os

import joblib
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import streamlit.components.v1 as components
from huggingface_hub import hf_hub_download
from lime import lime_text
from lime.lime_text import LimeTextExplainer
# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub repository and file name of the serialized clause classifier.
REPO_ID = "simplexico/cuad-sklearn-clause-classifier"
FILENAME = "CUAD-clause-classifier.pkl"

# Default clause used to pre-fill the sidebar text area.
EXAMPLE_TEXT = """This Agreement and any dispute or claim arising out of or in connection with it
or its subject matter or formation (including non-contractual disputes or claims) shall be
governed by and construed in accordance with the law of England."""
# ---- Page configuration and introductory copy ----
# NOTE(review): the icon/emoji characters inside the string literals below
# look like mis-decoded UTF-8 text. They are kept exactly as found; confirm
# the intended glyphs against the upstream file before changing them.
st.set_page_config(
    page_title="Label Clause Demo",
    page_icon="๐ท",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'mailto:hello@simplexico.ai',
        'Report a bug': None,
        'About': "## This a demo showcasing different Legal AI Actions",
    },
)

# Page heading followed by a short description of what the demo does.
st.title('๐ท Label Clause Demo')
st.write("""
This demo shows how AI can be used to label text.
We've trained an AI model to label a clause by its clause type.
""")
st.write("**๐ Enter a clause on the left** and hit the button **Label Clause** to see the demo in action")
@st.cache_resource(show_spinner=False)
def load_model():
    """Fetch the serialized clause classifier from the Hugging Face Hub and load it.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``; the download and the
    deserialization then happen once instead of on every rerun. The built-in
    cache spinner is disabled because the call site already wraps this call
    in its own ``st.spinner``.

    Returns:
        The object stored in ``FILENAME`` of repository ``REPO_ID``, as
        returned by ``joblib.load``.
    """
    # NOTE(security): joblib.load unpickles the file, and unpickling can run
    # arbitrary code. Only point REPO_ID / FILENAME at a trusted repository.
    model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=HF_TOKEN)
    return joblib.load(model_path)
def get_prediction_prob(text):
    """Classify a single clause and return its label with the class probabilities.

    Relies on the module-level ``model`` having been assigned before the
    first call.

    Args:
        text: The clause text to classify.

    Returns:
        A ``(y_pred, y_probs)`` pair: the first item of ``model.predict``'s
        output and the first row of ``model.predict_proba``'s output for a
        one-element batch containing ``text``.
    """
    # The model is called with a batch, so wrap the single clause in a list.
    batch = [text]
    label = model.predict(batch)[0]
    probabilities = model.predict_proba(batch)[0]
    return label, probabilities
# ---- Sidebar inputs ----
text = st.sidebar.text_area(label='Enter Clause Text', value=EXAMPLE_TEXT, height=250)
button = st.sidebar.button('**Label Clause**', type='primary', use_container_width=True)

# NOTE(review): the glyphs at the start of the spinner/header strings and the
# truncation marker below look like mis-decoded UTF-8 text. They are kept
# exactly as found; confirm the intended characters against the upstream file.
with st.spinner('โ๏ธ Loading model...'):
    model = load_model()
    # Upper-cased class labels; not referenced again in the visible script.
    classes = [s.upper() for s in model.classes_]

if button and not text.strip():
    # LIME explains a prediction by removing words from the input, so there
    # is nothing to classify or explain when the text area is blank.
    st.warning('Please enter some clause text before requesting a label.')
elif button:
    with st.spinner('โ๏ธ Processing Clause...'):
        y_pred, y_probs = get_prediction_prob(text)

        # Class names are cut to 9 characters for the LIME output; the
        # truncation marker is appended only when a name was actually cut.
        short_names = [
            cls if len(cls) <= 9 else cls[:9] + 'โฆ'
            for cls in model.classes_
        ]
        explainer = LimeTextExplainer(class_names=short_names)
        # top_labels=1 restricts the explanation to the highest-probability
        # class; num_features=10 caps the number of terms reported.
        exp = explainer.explain_instance(
            text,
            model.predict_proba,
            num_features=10,
            top_labels=1,
        )

    col1, col2 = st.columns(2)

    # Left column: predicted label plus a horizontal bar chart of the
    # per-class probabilities, expressed as percentages.
    with col1:
        st.header('๐ค Prediction Results')
        st.write(
            f"The model predicts that this is a **{y_pred}** clause with **{y_probs.max() * 100:.2f}%** confidence.")
        fig = go.Figure(go.Bar(
            x=y_probs * 100,
            y=model.classes_,
            orientation='h'))
        fig.update_layout(
            title="Model Confidence",
            xaxis_title="Confidence (%)",
            yaxis_title="Clause Type",
        )
        st.plotly_chart(fig, use_container_width=True)

    # Right column: the LIME explanation rendered as embedded HTML, without
    # LIME's own probability panel (predict_proba=False).
    with col2:
        st.header('๐ฎ Prediction Explainability')
        st.write(
            'We can perform an analysis to work out what terms in the clause were most important in deciding the predicted clause type:')
        components.html(exp.as_html(predict_proba=False), height=800)