# Hugging Face Space status: Running on CPU Upgrade
import os | |
import joblib | |
from copy import deepcopy | |
import pandas as pd | |
import plotly.express as px | |
from huggingface_hub import hf_hub_download, snapshot_download | |
import streamlit as st | |
import streamlit_analytics | |
from utils import add_logo_to_sidebar, add_footer, add_email_signup_form | |
# Hugging Face Hub access token; None when the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub location of the pretrained clustering pipeline (repo id, file name).
MODEL_REPO_ID, MODEL_FILENAME = (
    "simplexico/cuad-sklearn-contract-clustering",
    "cuad_tfidf_umap_kmeans.pkl",
)

# Hub location of the sample contract corpus (dataset repo id, file name).
DATA_REPO_ID, DATA_FILENAME = (
    "simplexico/cuad-top-ten",
    "cuad_top_ten_popular_contract_types.json",
)
# Begin streamlit-analytics tracking; the matching stop_tracking call is at the
# end of the script.
streamlit_analytics.start_tracking()

# Page metadata, layout and the "..." menu entries.
# NOTE(review): the page_icon / title / sidebar emoji look mis-encoded in this
# copy of the file; confirm the intended characters against the original.
st.set_page_config(
    page_title="Organise Demo",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'mailto:hello@simplexico.ai',
        'Report a bug': None,
        'About': "## This is a demo showcasing different Legal AI Actions"
    }
)

add_logo_to_sidebar()
st.sidebar.success("π Select a demo above.")

st.title('π Organise Demo')
st.write("""
This demo shows how AI can be used to organise contracts.
We've trained a model to group contracts into similar types.
The plot below shows a sample set of contracts that have been automatically grouped together.
Each point in the plot represents how the model interprets a contract, the closer together a pair of points are, the more similar they appear to the model.
Similar documents are grouped by color.
\n**TIP:** Hover over each point to see the filename of the contract. Groups can be added or removed by clicking on the symbol in the plot legend.
""")
# The button name quoted here must match the sidebar button label ('Organise Contracts').
st.write("**π Upload your own contracts on the left (as .txt files)** and hit the button **Organise Contracts** to see how your own contracts can be grouped together")
def load_model():
    """Fetch the pretrained contract-clustering pipeline from the Hugging Face Hub.

    Downloads ``MODEL_FILENAME`` from ``MODEL_REPO_ID`` (authenticating with
    ``HF_TOKEN``) and deserialises it with ``joblib``.

    Returns:
        The deserialised model object.
    """
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code; this is only safe while the model repo is trusted.
    local_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        token=HF_TOKEN,
    )
    return joblib.load(local_path)
def load_dataset():
    """Download the sample contract dataset and load it as a DataFrame.

    The whole ``DATA_REPO_ID`` dataset repo is snapshotted into the current
    working directory, then ``DATA_FILENAME`` is read from there with
    ``pd.read_json``.

    Returns:
        pandas.DataFrame: the sample contracts (the caller reads its ``text``
        and ``filename`` columns).
    """
    snapshot_download(
        repo_id=DATA_REPO_ID,
        repo_type='dataset',
        local_dir='./',
        token=HF_TOKEN,
    )
    # Relative path: relies on the snapshot having been written to the CWD above.
    return pd.read_json(DATA_FILENAME)
def get_transform_and_predictions(model, X):
    """Return the embedding of ``X`` and the cluster label predicted for each item.

    Args:
        model: a sliceable pipeline exposing ``predict``; ``model[:2]`` must
            expose ``transform``. Presumably a scikit-learn Pipeline whose first
            two steps are TF-IDF and UMAP (confirm against the pickled model).
        X: the documents to embed and label.

    Returns:
        tuple: ``(embedding, labels)`` where ``embedding`` is the output of the
        first two pipeline stages and ``labels`` is ``model.predict(X)``.
    """
    labels = model.predict(X)
    return model[:2].transform(X), labels
def generate_plot(X, y, filenames):
    """Build an interactive 3-D scatter plot of embedded documents.

    Args:
        X: 2-D array indexable as ``X[:, k]``; columns 0, 1 and 2 give the x, y
            and z coordinates of each point.
        y: one cluster label per point, used to colour the points.
        filenames: one name per point, shown as the hover title.

    Returns:
        The plotly figure (legend titled 'grouping', 1100x900 px).
    """
    # Labels are stringified so each group is plotted as a discrete category
    # (its own colour and legend entry) rather than on a continuous scale.
    group_labels = list(map(str, y))
    xs, ys, zs = X[:, 0], X[:, 1], X[:, 2]

    fig = px.scatter_3d(x=xs, y=ys, z=zs, color=group_labels, hover_name=filenames)

    fig.update_traces(
        selector={'mode': 'markers'},
        marker_size=8,
        marker_line={'width': 2},
    )

    # Pin the legend to the top-left corner of the plot area.
    legend_layout = {
        'title': 'grouping',
        'yanchor': 'top',
        'y': 0.99,
        'xanchor': 'left',
        'x': 0.01,
    }
    fig.update_layout(legend=legend_layout, width=1100, height=900)
    return fig
# Number of leading characters of each contract fed to the model (the sample
# corpus and the user's uploads are truncated the same way).
_PREFIX_CHARS = 500


def _choose_n_clusters(n_docs):
    """Return the k-means cluster count to use for ``n_docs`` uploaded contracts.

    Uses 3 groups for fewer than 10 contracts and 8 otherwise, capped at
    ``n_docs`` because k-means cannot form more clusters than there are samples.
    """
    n_clusters = 3 if n_docs < 10 else 8
    return min(n_clusters, n_docs)


def _read_contract_prefix(uploaded_file):
    """Return the first ``_PREFIX_CHARS`` characters of an uploaded file as text.

    Decoding before truncating avoids cutting a multi-byte character in half, and
    gives the model ``str`` input like the sample corpus. Assumes the uploads are
    UTF-8 text files; undecodable bytes are replaced rather than raising.
    """
    return uploaded_file.read().decode('utf-8', errors='replace')[:_PREFIX_CHARS]


# Sidebar controls: the user's own contracts and the button that clusters them.
uploaded_files = st.sidebar.file_uploader("Select contracts to organise ", accept_multiple_files=True)
button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)

# NOTE(review): the emoji in the spinner labels look mis-encoded in this copy of
# the file; confirm the intended characters against the original.
with st.container():
    with st.spinner('βοΈ Loading model...'):
        cuad_tfidf_umap_kmeans = load_model()
        cuad_df = load_dataset()

    # Plot the pretrained model's grouping of the sample CUAD contracts.
    X = [text[:_PREFIX_CHARS] for text in cuad_df['text'].to_list()]
    filenames = cuad_df['filename'].to_list()
    X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X)
    fig = generate_plot(X_transform, y, filenames)
    figure = st.plotly_chart(fig, use_container_width=True)

    if button:
        # Replace the sample plot with the grouping of the user's contracts.
        figure.empty()
        with st.spinner('βοΈ Training model...'):
            if not uploaded_files or len(uploaded_files) < 2:
                st.write(
                    "Please add at least two contracts"
                )
            else:
                X_train = [_read_contract_prefix(uploaded_file) for uploaded_file in uploaded_files]
                filenames = [uploaded_file.name for uploaded_file in uploaded_files]

                # Refit a copy so the pretrained pipeline used above is not mutated.
                tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
                tfidf_umap_kmeans.set_params(kmeans__n_clusters=_choose_n_clusters(len(X_train)))
                tfidf_umap_kmeans.fit(X_train)

                # Embed and label with the freshly fitted copy, not the pretrained model.
                X_transform, y = get_transform_and_predictions(tfidf_umap_kmeans, X_train)
                fig = generate_plot(X_transform, y, filenames)
                st.write("**Your organised contracts:**")
                st.plotly_chart(fig, use_container_width=True)

add_email_signup_form()
add_footer()

# Fails with KeyError when ANALYTICS_PASSWORD is unset, rather than exposing the
# analytics dashboard without a password.
streamlit_analytics.stop_tracking(unsafe_password=os.environ["ANALYTICS_PASSWORD"])