Rainsilves committed on
Commit
bd550e0
1 Parent(s): 6636abe

added significant performance improvements

Browse files
Files changed (1) hide show
  1. interpretable_text_clustering.py +15 -6
interpretable_text_clustering.py CHANGED
@@ -48,10 +48,14 @@ if task == "Classification":
48
 
49
  form.form_submit_button("Submit")
50
 
 
 
 
 
 
 
51
 
52
- dataset = load_dataset(path = dataset_name, name = dataset_name_2, streaming=True)
53
- dataset_head = dataset[split_name].take(number_of_records)
54
- df = pd.DataFrame.from_dict(dataset_head)
55
 
56
  display_full_df = form.checkbox("Display the full dataset dataframe?")
57
  display_X_df = form.checkbox("Display the training data?", value = True)
@@ -73,13 +77,17 @@ model_name = form.text_area("Enter the name of the pre-trained model from senten
73
  form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
74
  form.caption("See the list of pre-trained models that are available here! https://www.sbert.net/docs/pretrained_models.html")
75
 
76
- model = SentenceTransformer(model_name)
77
 
78
 
79
 
80
- #embeddings = model.encode(sentences, convert_to_numpy = True)
 
 
 
 
81
 
82
- embedder = FunctionTransformer(lambda item:model.encode(item, convert_to_numpy=True, show_progress_bar=False))
83
 
84
 
85
 
@@ -114,6 +122,7 @@ if task == "Clustering":
114
  ('vect', embedder),
115
  ('cluster', KMeans(n_clusters = n_clusters, n_init = n_init, max_iter = max_iter)),
116
  ])
 
117
 
118
  if task == "Classification":
119
  text_clf.fit(df[column_name], df[labels_column_name])
 
48
 
49
  form.form_submit_button("Submit")
50
 
51
+ @st.cache
52
+ def load_and_process_data(path, name, streaming, split_name, number_of_records):
53
+ dataset = load_dataset(path = path, name = name, streaming=streaming)
54
+ dataset_head = dataset[split_name].take(number_of_records)
55
+ df = pd.DataFrame.from_dict(dataset_head)
56
+ return df
57
 
58
+ df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
 
 
59
 
60
  display_full_df = form.checkbox("Display the full dataset dataframe?")
61
  display_X_df = form.checkbox("Display the training data?", value = True)
 
77
  form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
78
  form.caption("See the list of pre-trained models that are available here! https://www.sbert.net/docs/pretrained_models.html")
79
 
80
+ #embeddings = model.encode(sentences, convert_to_numpy = True)
81
 
82
 
83
 
84
+ @st.cache
85
+ def load_model_and_return_embedder(model_name):
86
+ model = SentenceTransformer(model_name)
87
+ embedder = FunctionTransformer(lambda item:model.encode(item, convert_to_numpy=True, show_progress_bar=False))
88
+ return embedder
89
 
90
+ embedder = load_model_and_return_embedder(model_name=model_name)
91
 
92
 
93
 
 
122
  ('vect', embedder),
123
  ('cluster', KMeans(n_clusters = n_clusters, n_init = n_init, max_iter = max_iter)),
124
  ])
125
+
126
 
127
  if task == "Classification":
128
  text_clf.fit(df[column_name], df[labels_column_name])