asoria HF staff committed on
Commit
c3813c7
·
1 Parent(s): 560300f

Fix for small datasets and custom topics

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -152,7 +152,7 @@ def generate_topics(dataset, config, split, column, nested_column):
152
  base_model = None
153
  all_docs = []
154
  reduced_embeddings_list = []
155
-
156
  while offset < limit:
157
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
158
  if not docs:
@@ -164,11 +164,13 @@ def generate_topics(dataset, config, split, column, nested_column):
164
 
165
  embeddings = calculate_embeddings(docs)
166
  base_model, _ = fit_model(base_model, docs, embeddings)
167
- llama2_labels = [
168
- label[0][0].split("\n")[0]
169
- for label in base_model.get_topics(full=True)["Llama2"].values()
170
- ]
171
- base_model.set_topic_labels(llama2_labels)
 
 
172
 
173
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
174
  reduced_embeddings_list.append(reduced_embeddings)
@@ -189,7 +191,7 @@ def generate_topics(dataset, config, split, column, nested_column):
189
  offset += chunk_size
190
 
191
  logging.info("Finished processing all data")
192
- return base_model.get_topic_info(), base_model.visualize_topics()
193
 
194
 
195
  with gr.Blocks() as demo:
 
152
  base_model = None
153
  all_docs = []
154
  reduced_embeddings_list = []
155
+ topics_info, topic_plot = None, None
156
  while offset < limit:
157
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
158
  if not docs:
 
164
 
165
  embeddings = calculate_embeddings(docs)
166
  base_model, _ = fit_model(base_model, docs, embeddings)
167
+
168
+ repr_model_topics = {
169
+ key: label[0][0].split("\n")[0]
170
+ for key, label in base_model.get_topics(full=True)["Llama2"].items()
171
+ }
172
+
173
+ base_model.set_topic_labels(repr_model_topics)
174
 
175
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
176
  reduced_embeddings_list.append(reduced_embeddings)
 
191
  offset += chunk_size
192
 
193
  logging.info("Finished processing all data")
194
+ return topics_info, topic_plot
195
 
196
 
197
  with gr.Blocks() as demo: