asoria HF staff commited on
Commit
560300f
·
1 Parent(s): 4996a19
Files changed (2) hide show
  1. app.py +22 -20
  2. prompts.py +5 -3
app.py CHANGED
@@ -7,7 +7,6 @@ from bertopic import BERTopic
7
  import gradio as gr
8
  from bertopic.representation import (
9
  KeyBERTInspired,
10
- MaximalMarginalRelevance,
11
  TextGeneration,
12
  )
13
  from umap import UMAP
@@ -19,8 +18,7 @@ from transformers import (
19
  AutoModelForCausalLM,
20
  pipeline,
21
  )
22
- from prompts import system_prompt, example_prompt, main_prompt
23
- from umap import UMAP
24
  from hdbscan import HDBSCAN
25
  from sklearn.feature_extraction.text import CountVectorizer
26
 
@@ -36,7 +34,6 @@ logging.basicConfig(
36
  session = requests.Session()
37
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
38
  keybert = KeyBERTInspired()
39
- mmr = MaximalMarginalRelevance(diversity=0.3)
40
  vectorizer_model = CountVectorizer(stop_words="english")
41
 
42
  model_id = "meta-llama/Llama-2-7b-chat-hf"
@@ -52,7 +49,6 @@ bnb_config = BitsAndBytesConfig(
52
 
53
  tokenizer = AutoTokenizer.from_pretrained(model_id)
54
 
55
- # Llama 2 Model
56
  model = AutoModelForCausalLM.from_pretrained(
57
  model_id,
58
  trust_remote_code=True,
@@ -68,13 +64,11 @@ generator = pipeline(
68
  max_new_tokens=500,
69
  repetition_penalty=1.1,
70
  )
71
- prompt = system_prompt + example_prompt + main_prompt
72
 
73
- llama2 = TextGeneration(generator, prompt=prompt)
74
  representation_model = {
75
  "KeyBERT": keybert,
76
  "Llama2": llama2,
77
- # "MMR": mmr,
78
  }
79
 
80
  umap_model = UMAP(
@@ -132,9 +126,9 @@ def fit_model(base_model, docs, embeddings):
132
  verbose=True,
133
  min_topic_size=15,
134
  )
135
- logging.info("Fitting new model")
136
  new_model.fit(docs, embeddings)
137
- logging.info("End fitting new model")
138
 
139
  if base_model is None:
140
  return new_model, new_model
@@ -157,35 +151,43 @@ def generate_topics(dataset, config, split, column, nested_column):
157
  offset = 0
158
  base_model = None
159
  all_docs = []
160
- all_reduced_embeddings = np.empty((0, 2))
161
- while True:
 
162
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
 
 
 
163
  logging.info(
164
- f"------------> New chunk data {offset=} {chunk_size=} with {len(docs)} docs"
165
  )
 
166
  embeddings = calculate_embeddings(docs)
167
- offset = offset + chunk_size
168
- if not docs or offset >= limit:
169
- break
170
  base_model, _ = fit_model(base_model, docs, embeddings)
171
  llama2_labels = [
172
  label[0][0].split("\n")[0]
173
  for label in base_model.get_topics(full=True)["Llama2"].values()
174
  ]
175
- logging.info(f"Topics: {llama2_labels}")
176
  base_model.set_topic_labels(llama2_labels)
177
 
178
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
 
179
 
180
  all_docs.extend(docs)
181
- all_reduced_embeddings = np.vstack((all_reduced_embeddings, reduced_embeddings))
182
  topics_info = base_model.get_topic_info()
183
  topic_plot = base_model.visualize_documents(
184
- all_docs, reduced_embeddings=all_reduced_embeddings, custom_labels=True
 
 
185
  )
186
- logging.info(f"Topics for merged model: {base_model.topic_labels_}")
 
 
187
  yield topics_info, topic_plot
188
 
 
 
189
  logging.info("Finished processing all data")
190
  return base_model.get_topic_info(), base_model.visualize_topics()
191
 
 
7
  import gradio as gr
8
  from bertopic.representation import (
9
  KeyBERTInspired,
 
10
  TextGeneration,
11
  )
12
  from umap import UMAP
 
18
  AutoModelForCausalLM,
19
  pipeline,
20
  )
21
+ from prompts import REPRESENTATION_PROMPT
 
22
  from hdbscan import HDBSCAN
23
  from sklearn.feature_extraction.text import CountVectorizer
24
 
 
34
  session = requests.Session()
35
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
36
  keybert = KeyBERTInspired()
 
37
  vectorizer_model = CountVectorizer(stop_words="english")
38
 
39
  model_id = "meta-llama/Llama-2-7b-chat-hf"
 
49
 
50
  tokenizer = AutoTokenizer.from_pretrained(model_id)
51
 
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  model_id,
54
  trust_remote_code=True,
 
64
  max_new_tokens=500,
65
  repetition_penalty=1.1,
66
  )
 
67
 
68
+ llama2 = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
69
  representation_model = {
70
  "KeyBERT": keybert,
71
  "Llama2": llama2,
 
72
  }
73
 
74
  umap_model = UMAP(
 
126
  verbose=True,
127
  min_topic_size=15,
128
  )
129
+ logging.debug("Fitting new model")
130
  new_model.fit(docs, embeddings)
131
+ logging.debug("End fitting new model")
132
 
133
  if base_model is None:
134
  return new_model, new_model
 
151
  offset = 0
152
  base_model = None
153
  all_docs = []
154
+ reduced_embeddings_list = []
155
+
156
+ while offset < limit:
157
  docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
158
+ if not docs:
159
+ break
160
+
161
  logging.info(
162
+ f"----> Processing chunk: {offset=} {chunk_size=} with {len(docs)} docs"
163
  )
164
+
165
  embeddings = calculate_embeddings(docs)
 
 
 
166
  base_model, _ = fit_model(base_model, docs, embeddings)
167
  llama2_labels = [
168
  label[0][0].split("\n")[0]
169
  for label in base_model.get_topics(full=True)["Llama2"].values()
170
  ]
 
171
  base_model.set_topic_labels(llama2_labels)
172
 
173
  reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
174
+ reduced_embeddings_list.append(reduced_embeddings)
175
 
176
  all_docs.extend(docs)
177
+
178
  topics_info = base_model.get_topic_info()
179
  topic_plot = base_model.visualize_documents(
180
+ all_docs,
181
+ reduced_embeddings=np.vstack(reduced_embeddings_list),
182
+ custom_labels=True,
183
  )
184
+
185
+ logging.info(f"Topics: {llama2_labels}")
186
+
187
  yield topics_info, topic_plot
188
 
189
+ offset += chunk_size
190
+
191
  logging.info("Finished processing all data")
192
  return base_model.get_topic_info(), base_model.visualize_topics()
193
 
prompts.py CHANGED
@@ -1,10 +1,10 @@
1
- system_prompt = """
2
  <s>[INST] <<SYS>>
3
  You are a helpful, respectful and honest assistant for labeling topics.
4
  <</SYS>>
5
  """
6
 
7
- example_prompt = """
8
  I have a topic that contains the following documents:
9
  - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
10
  - Meat, but especially beef, is the word food in terms of emissions.
@@ -17,7 +17,7 @@ Based on the information about the topic above, please create a short label of t
17
  [/INST] Environmental impacts of eating meat
18
  """
19
 
20
- main_prompt = """
21
  [INST]
22
  I have a topic that contains the following documents:
23
  [DOCUMENTS]
@@ -27,3 +27,5 @@ The topic is described by the following keywords: '[KEYWORDS]'.
27
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
28
  [/INST]
29
  """
 
 
 
1
+ SYSTEM_PROMPT = """
2
  <s>[INST] <<SYS>>
3
  You are a helpful, respectful and honest assistant for labeling topics.
4
  <</SYS>>
5
  """
6
 
7
+ EXAMPLE_PROMPT = """
8
  I have a topic that contains the following documents:
9
  - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
10
  - Meat, but especially beef, is the word food in terms of emissions.
 
17
  [/INST] Environmental impacts of eating meat
18
  """
19
 
20
+ MAIN_PROMPT = """
21
  [INST]
22
  I have a topic that contains the following documents:
23
  [DOCUMENTS]
 
27
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
28
  [/INST]
29
  """
30
+
31
+ REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT