asoria HF staff commited on
Commit
2269797
·
1 Parent(s): e65c78c

Disable ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import (
7
  )
8
 
9
  # These imports at the end because of torch/datamapplot issue in Zero GPU
10
- import spaces
11
  import gradio as gr
12
 
13
  import logging
@@ -93,8 +93,6 @@ representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
93
 
94
  vectorizer_model = CountVectorizer(stop_words="english")
95
 
96
- global_topic_model = None
97
-
98
 
99
  def get_split_rows(dataset, config, split):
100
  config_size = session.get(
@@ -131,7 +129,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
131
  return df[column].tolist()
132
 
133
 
134
- @spaces.GPU
135
  def calculate_embeddings(docs):
136
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
137
 
@@ -142,10 +140,8 @@ def calculate_n_neighbors_and_components(n_rows):
142
  return n_neighbors, n_components
143
 
144
 
145
- @spaces.GPU
146
  def fit_model(docs, embeddings, n_neighbors, n_components):
147
- global global_topic_model
148
-
149
  umap_model = UMAP(
150
  n_neighbors=n_neighbors,
151
  n_components=n_components,
@@ -180,9 +176,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
180
  new_model.fit(docs, embeddings)
181
  logging.info("End fitting new model")
182
 
183
- global_topic_model = new_model
184
-
185
- logging.info("Global model updated")
186
 
187
 
188
  def _push_to_hub(
@@ -207,7 +201,6 @@ def _push_to_hub(
207
 
208
 
209
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
210
- global global_topic_model
211
  logging.info(
212
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
213
  )
@@ -257,12 +250,12 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
257
  )
258
 
259
  embeddings = calculate_embeddings(docs)
260
- fit_model(docs, embeddings, n_neighbors, n_components)
261
 
262
  if base_model is None:
263
- base_model = global_topic_model
264
  else:
265
- updated_model = BERTopic.merge_models([base_model, global_topic_model])
266
  nr_new_topics = len(set(updated_model.topics_)) - len(
267
  set(base_model.topics_)
268
  )
 
7
  )
8
 
9
  # These imports at the end because of torch/datamapplot issue in Zero GPU
10
+ # import spaces
11
  import gradio as gr
12
 
13
  import logging
 
93
 
94
  vectorizer_model = CountVectorizer(stop_words="english")
95
 
 
 
96
 
97
  def get_split_rows(dataset, config, split):
98
  config_size = session.get(
 
129
  return df[column].tolist()
130
 
131
 
132
+ # @spaces.GPU
133
  def calculate_embeddings(docs):
134
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
135
 
 
140
  return n_neighbors, n_components
141
 
142
 
143
+ # @spaces.GPU
144
  def fit_model(docs, embeddings, n_neighbors, n_components):
 
 
145
  umap_model = UMAP(
146
  n_neighbors=n_neighbors,
147
  n_components=n_components,
 
176
  new_model.fit(docs, embeddings)
177
  logging.info("End fitting new model")
178
 
179
+ return new_model
 
 
180
 
181
 
182
  def _push_to_hub(
 
201
 
202
 
203
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
 
204
  logging.info(
205
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
206
  )
 
250
  )
251
 
252
  embeddings = calculate_embeddings(docs)
253
+ new_model = fit_model(docs, embeddings, n_neighbors, n_components)
254
 
255
  if base_model is None:
256
+ base_model = new_model
257
  else:
258
+ updated_model = BERTopic.merge_models([base_model, new_model])
259
  nr_new_topics = len(set(updated_model.topics_)) - len(
260
  set(base_model.topics_)
261
  )