JMuscatello commited on
Commit
50b327c
β€’
1 Parent(s): 6d70e63

Change caching structure

Browse files
Files changed (1) hide show
  1. pages/5_πŸ—‚_Organise_Demo.py +48 -32
pages/5_πŸ—‚_Organise_Demo.py CHANGED
@@ -91,53 +91,69 @@ def generate_plot(X, y, filenames):
91
 
92
  return fig
93
 
94
- uploaded_files = st.sidebar.file_uploader("Select contracts to organise ", accept_multiple_files=True)
 
 
 
95
 
96
- button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)
97
 
98
- with st.container():
99
- with st.spinner('βš™οΈ Loading model...'):
100
- cuad_tfidf_umap_kmeans = load_model()
101
- cuad_df = load_dataset()
102
 
103
- X = [text[:500] for text in cuad_df['text'].to_list()]
104
- filenames = cuad_df['filename'].to_list()
105
 
106
- X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X)
 
 
 
107
 
108
- fig = generate_plot(X_transform, y, filenames)
 
109
 
110
- figure = st.plotly_chart(fig, use_container_width=True)
111
 
112
- if button:
113
- figure.empty()
114
 
115
- with st.spinner('βš™οΈ Training model...'):
 
 
 
116
 
117
- if not uploaded_files or not len(uploaded_files) > 1:
118
- st.write(
119
- "Please add at least two contracts"
120
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  else:
122
- if len(uploaded_files) < 10:
123
- n_clusters = 3
124
- else:
125
- n_clusters = 8
126
-
127
- X_train = [uploaded_file.read()[:500] for uploaded_file in uploaded_files]
128
- filenames = [uploaded_file.name for uploaded_file in uploaded_files]
129
 
130
- tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
131
- tfidf_umap_kmeans.set_params(kmeans__n_clusters=4)
132
- tfidf_umap_kmeans.fit(X_train)
133
 
134
- X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X_train)
135
 
136
- fig = generate_plot(X_transform, y, filenames)
137
 
138
- st.write("**Your organised contracts:**")
139
 
140
- st.plotly_chart(fig, use_container_width=True)
141
 
142
  add_email_signup_form()
143
 
 
91
 
92
  return fig
93
 
94
+ @st.cache(allow_output_mutation=True)
95
+ def prepare_figure(model, df):
96
+ X = [text[:500] for text in df['text'].to_list()]
97
+ filenames = df['filename'].to_list()
98
 
99
+ X_transform, y = get_transform_and_predictions(model, X)
100
 
101
+ fig = generate_plot(X_transform, y, filenames)
 
 
 
102
 
103
+ return fig
 
104
 
105
+ @st.cache()
106
+ def prepare_page():
107
+ model = load_model()
108
+ df = load_dataset()
109
 
110
+ X = [text[:500] for text in df['text'].to_list()]
111
+ filenames = df['filename'].to_list()
112
 
113
+ X_transform, y = get_transform_and_predictions(model, X)
114
 
115
+ fig = prepare_figure(model, df)
 
116
 
117
+ return fig, model
118
+
119
+
120
+ uploaded_files = st.sidebar.file_uploader("Select contracts to organise ", accept_multiple_files=True)
121
 
122
+ button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)
123
+
124
+ with st.spinner('βš™οΈ Loading model...'):
125
+ fig, cuad_tfidf_umap_kmeans = prepare_page()
126
+ figure = st.plotly_chart(fig, use_container_width=True)
127
+
128
+ if button:
129
+ figure.empty()
130
+
131
+ with st.spinner('βš™οΈ Training model...'):
132
+
133
+ if not uploaded_files or not len(uploaded_files) > 2:
134
+ st.write(
135
+ "**Please add at least three contracts**"
136
+ )
137
+ else:
138
+ if len(uploaded_files) < 10:
139
+ n_clusters = 3
140
  else:
141
+ n_clusters = 8
142
+
143
+ X_train = [uploaded_file.read()[:500] for uploaded_file in uploaded_files]
144
+ filenames = [uploaded_file.name for uploaded_file in uploaded_files]
 
 
 
145
 
146
+ tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
147
+ tfidf_umap_kmeans.set_params(kmeans__n_clusters=n_clusters)
148
+ tfidf_umap_kmeans.fit(X_train)
149
 
150
+ X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X_train)
151
 
152
+ fig = generate_plot(X_transform, y, filenames)
153
 
154
+ st.write("**Your organised contracts:**")
155
 
156
+ st.plotly_chart(fig, use_container_width=True)
157
 
158
  add_email_signup_form()
159