Rainsilves
committed on
Commit
•
bd550e0
1
Parent(s):
6636abe
added significant performance improvements
Browse files
interpretable_text_clustering.py
CHANGED
@@ -48,10 +48,14 @@ if task == "Classification":
|
|
48 |
|
49 |
form.form_submit_button("Submit")
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
dataset_head = dataset[split_name].take(number_of_records)
|
54 |
-
df = pd.DataFrame.from_dict(dataset_head)
|
55 |
|
56 |
display_full_df = form.checkbox("Display the full dataset dataframe?")
|
57 |
display_X_df = form.checkbox("Display the training data?", value = True)
|
@@ -73,13 +77,17 @@ model_name = form.text_area("Enter the name of the pre-trained model from senten
|
|
73 |
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
|
74 |
form.caption("See the list of pre-trained models that are available here! https://www.sbert.net/docs/pretrained_models.html")
|
75 |
|
76 |
-
model =
|
77 |
|
78 |
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
embedder =
|
83 |
|
84 |
|
85 |
|
@@ -114,6 +122,7 @@ if task == "Clustering":
|
|
114 |
('vect', embedder),
|
115 |
('cluster', KMeans(n_clusters = n_clusters, n_init = n_init, max_iter = max_iter)),
|
116 |
])
|
|
|
117 |
|
118 |
if task == "Classification":
|
119 |
text_clf.fit(df[column_name], df[labels_column_name])
|
|
|
48 |
|
49 |
form.form_submit_button("Submit")
|
50 |
|
51 |
+
@st.cache
|
52 |
+
def load_and_process_data(path, name, streaming, split_name, number_of_records):
|
53 |
+
dataset = load_dataset(path = path, name = name, streaming=streaming)
|
54 |
+
dataset_head = dataset[split_name].take(number_of_records)
|
55 |
+
df = pd.DataFrame.from_dict(dataset_head)
|
56 |
+
return df
|
57 |
|
58 |
+
df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
|
|
|
|
|
59 |
|
60 |
display_full_df = form.checkbox("Display the full dataset dataframe?")
|
61 |
display_X_df = form.checkbox("Display the training data?", value = True)
|
|
|
77 |
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
|
78 |
form.caption("See the list of pre-trained models that are available here! https://www.sbert.net/docs/pretrained_models.html")
|
79 |
|
80 |
+
#embeddings = model.encode(sentences, convert_to_numpy = True)
|
81 |
|
82 |
|
83 |
|
84 |
+
@st.cache
|
85 |
+
def load_model_and_return_embedder(model_name):
|
86 |
+
model = SentenceTransformer(model_name)
|
87 |
+
embedder = FunctionTransformer(lambda item:model.encode(item, convert_to_numpy=True, show_progress_bar=False))
|
88 |
+
return embedder
|
89 |
|
90 |
+
embedder = load_model_and_return_embedder(model_name=model_name)
|
91 |
|
92 |
|
93 |
|
|
|
122 |
('vect', embedder),
|
123 |
('cluster', KMeans(n_clusters = n_clusters, n_init = n_init, max_iter = max_iter)),
|
124 |
])
|
125 |
+
|
126 |
|
127 |
if task == "Classification":
|
128 |
text_clf.fit(df[column_name], df[labels_column_name])
|