Added clean data options, improved re-representation options and visualisation. General format changes.

Files changed:
- .gitignore +2 -0
- app.py +32 -528
- funcs/anonymiser.py +44 -10
- funcs/clean_funcs.py +107 -0
- funcs/helper_functions.py +29 -1
- funcs/presidio_analyzer_custom.py +114 -0
- funcs/representation_model.py +22 -22
- funcs/topic_core_funcs.py +500 -0
- requirements.txt +4 -1
.gitignore
CHANGED
@@ -2,6 +2,8 @@
 *.ipynb
 *.npz
 *.csv
+*.xlsx
+*.xls
 *.pkl
 *.parquet
 *.png
app.py
CHANGED
@@ -1,516 +1,13 @@
-import os
-
 # Dendrograms will not work with the latest version of scipy (1.12.0), so installing the version prior to be safe
-os.system("pip install scipy==1.11.4")
+#os.system("pip install scipy==1.11.4")
 
 import gradio as gr
-from datetime import datetime
 import pandas as pd
 import numpy as np
-import time
 
-from
+from funcs.topic_core_funcs import pre_clean, extract_topics, reduce_outliers, represent_topics, visualise_topics, save_as_pytorch_model
+from funcs.helper_functions import dummy_function, initial_file_load
 from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.pipeline import make_pipeline
-from sklearn.decomposition import TruncatedSVD
-from sklearn.feature_extraction.text import TfidfVectorizer
-import funcs.anonymiser as anon
-from umap import UMAP
-
-from torch import cuda, backends, version
-
-# Default seed, can be changed in number selection on options page
-random_seed = 42
-
-# Check for torch cuda
-# If you want to disable cuda for testing purposes
-#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-
-print("Is CUDA enabled? ", cuda.is_available())
-print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
-if cuda.is_available():
-    torch_device = "gpu"
-    print("Cuda version installed is: ", version.cuda)
-    low_resource_mode = "No"
-    #os.system("nvidia-smi")
-else:
-    torch_device = "cpu"
-    low_resource_mode = "Yes"
-
-print("Device used is: ", torch_device)
-
-
-
-from bertopic import BERTopic
-
-
-today = datetime.now().strftime("%d%m%Y")
-today_rev = datetime.now().strftime("%Y%m%d")
-
-from funcs.helper_functions import dummy_function, initial_file_load, read_file, zip_folder, delete_files_in_folder, save_topic_outputs
-#from funcs.representation_model import representation_model
-from funcs.embeddings import make_or_load_embeddings
-
-# Log terminal output: https://github.com/gradio-app/gradio/issues/2362
-import sys
-
-class Logger:
-    def __init__(self, filename):
-        self.terminal = sys.stdout
-        self.log = open(filename, "w")
-
-    def write(self, message):
-        self.terminal.write(message)
-        self.log.write(message)
-
-    def flush(self):
-        self.terminal.flush()
-        self.log.flush()
-
-    def isatty(self):
-        return False
-
-sys.stdout = Logger("output.log")
-
-def read_logs():
-    sys.stdout.flush()
-    with open("output.log", "r") as f:
-        return f.read()
-
-# Load embeddings
-embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"
-
-# Use of Jina deprecated - kept here for posterity
-# Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
-# Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
-# local_embeddings_location = "model/jina/"
-#revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
-#revision_choice = "69d43700292701b06c24f43b96560566a4e5ad1f"
-
-# Model used for representing topics
-hf_model_name = 'second-state/stablelm-2-zephyr-1.6b-GGUF' #'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1.9-GGUF'
-hf_model_file = 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf' # 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf'
-
-def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext, custom_labels_df, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, random_seed, calc_probs, progress=gr.Progress(track_tqdm=True)):
-
-    progress(0, desc= "Loading data")
-
-    if calc_probs == "No":
-        calc_probs = False
-    elif calc_probs == "Yes":
-        print("Calculating all probabilities.")
-        calc_probs == True
-
-    if not in_colnames:
-        error_message = "Please enter one column name to use to find topics."
-        print(error_message)
-        return error_message, None, embeddings_out, data_file_name_no_ext, None, None
-
-    all_tic = time.perf_counter()
-
-    output_list = []
-    file_list = [string.name for string in in_files]
-
-    in_colnames_list_first = in_colnames[0]
-
-    docs = list(data[in_colnames_list_first].str.lower())
-
-    if anonymise_drop == "Yes":
-        progress(0.1, desc= "Anonymising data")
-        anon_tic = time.perf_counter()
-
-        data_anon_col, anonymisation_success = anon.anonymise_script(data, in_colnames_list_first, anon_strat="replace")
-        data[in_colnames_list_first] = data_anon_col[in_colnames_list_first]
-        anonymise_data_name = data_file_name_no_ext + "_anonymised_" + today_rev + ".csv"
-        data.to_csv(anonymise_data_name)
-        output_list.append(anonymise_data_name)
-
-        print(anonymisation_success)
-
-        anon_toc = time.perf_counter()
-        time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
-
-    # Check if embeddings are being loaded in
-    progress(0.2, desc= "Loading/creating embeddings")
-
-    print("Low resource mode: ", low_resource_mode)
-
-    if low_resource_mode == "No":
-        print("Using high resource BGE transformer model")
-
-        embedding_model = SentenceTransformer(embeddings_name)
-
-        # Use of Jina now superseded by BGE, keeping this code just in case I consider reverting one day
-        #try:
-        #embedding_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True,device_map="auto") # For Jina
-        #except:
-        #    embedding_model = AutoModel.from_pretrained(embeddings_name)#, revision = revision_choice, trust_remote_code=True, device_map="auto", use_auth_token=os.environ["HF_TOKEN"])
-        #tokenizer = AutoTokenizer.from_pretrained(embeddings_name)
-        #embedding_model_pipe = pipeline("feature-extraction", model=embedding_model, tokenizer=tokenizer)
-
-        # UMAP model uses Bertopic defaults
-        umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', low_memory=False, random_state=random_seed)
-
-    elif low_resource_mode == "Yes":
-        print("Choosing low resource TF-IDF model.")
-
-        embedding_model_pipe = make_pipeline(
-            TfidfVectorizer(),
-            TruncatedSVD(100) # 100 # To be compatible with zero shot, this needs to be lower than number of suggested topics
-        )
-        embedding_model = embedding_model_pipe
-
-        umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
-
-    embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
-
-    vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
-
-    # Representation model not currently used in this function
-    #print("Create Keybert-like topic representations by default")
-    #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
-    #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
-
-
-    progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
-
-    if not candidate_topics:
-
-        topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, #for Jina
-                                vectorizer_model=vectoriser_model,
-                                umap_model=umap_model,
-                                min_topic_size = min_docs_slider,
-                                nr_topics = max_topics_slider,
-                                calculate_probabilities=calc_probs,
-                                #representation_model=representation_model,
-                                verbose = True)
-
-        assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
-
-        #print(assigned_topics)
-
-        # Replace original labels with Keybert labels
-        #if "KeyBERT" in topic_model.get_topic_info().columns:
-        #    keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
-        #    topic_model.set_topic_labels(keybert_labels)
-
-
-    # Do this if you have pre-defined topics
-    else:
-        if low_resource_mode == "Yes":
-            error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
-            print(error_message)
-
-            return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs
-
-        zero_shot_topics = read_file(candidate_topics.name)
-        zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
-
-
-
-        topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, # for Jina
-                                vectorizer_model=vectoriser_model,
-                                umap_model=umap_model,
-                                min_topic_size = min_docs_slider,
-                                nr_topics = max_topics_slider,
-                                zeroshot_topic_list = zero_shot_topics_lower,
-                                zeroshot_min_similarity = zero_shot_similarity, # 0.7
-                                calculate_probabilities=calc_probs,
-                                #representation_model=representation_model,
-                                verbose = True)
-
-        assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
-
-        # For some reason, zero topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
-        if isinstance(assigned_topics, np.ndarray):
-            assigned_topics = assigned_topics.tolist()
-            #print(assigned_topics.tolist())
-
-        # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot, we are merging two models that have exactly the same set of documents, the vocubulary should be the same, and so recreating the cf_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700
-
-        doc_dets = topic_model.get_document_info(docs)
-
-        documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
-
-        # Assign CountVectorizer to merged model
-
-        topic_model.vectorizer_model = vectoriser_model
-
-        # Re-calculate c-TF-IDF
-        c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
-        topic_model.c_tf_idf_ = c_tf_idf
-
-    # Replace original labels with Keybert labels
-    #if "KeyBERT" in topic_model.get_topic_info().columns:
-    #    print(topic_model.get_topics(full=True)["KeyBERT"].values())
-    #    keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
-    #    topic_model.set_topic_labels(keybert_labels)
-
-    if not assigned_topics:
-        # Handle the empty array case
-        return "No topics found.", output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
-
-    else:
-        print("Topic model created.")
-
-    # Replace current topic labels if new ones loaded in
-    if not custom_labels_df.empty:
-        #custom_label_list = list(custom_labels_df.iloc[:,0])
-        custom_label_list = [label.replace("\n", "") for label in custom_labels_df.iloc[:,0]]
-
-        topic_model.set_topic_labels(custom_label_list)
-        #topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model)
-
-
-    print("Custom topics: ", topic_model.custom_labels_)
-
-    # Outputs
-    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
-
-    # If you want to save your embedding files
-    if return_intermediate_files == "Yes":
-        print("Saving embeddings to file")
-        if low_resource_mode == "Yes":
-            embeddings_file_name = data_file_name_no_ext + '_' + 'tfidf_embeddings.npz'
-        else:
-            if embeddings_super_compress == "No":
-                embeddings_file_name = data_file_name_no_ext + '_' + 'bge_embeddings.npz'
-            else:
-                embeddings_file_name = data_file_name_no_ext + '_' + 'bge_embeddings_compress.npz'
-
-        np.savez_compressed(embeddings_file_name, embeddings_out)
-
-        output_list.append(embeddings_file_name)
-
-    all_toc = time.perf_counter()
-    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
-    print(time_out)
-
-    return output_text, output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
-
-def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):
-
-    progress(0, desc= "Preparing data")
-
-    output_list = []
-
-    all_tic = time.perf_counter()
-
-    assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
-
-    if isinstance(assigned_topics, np.ndarray):
-        assigned_topics = assigned_topics.tolist()
-
-    #progress(0.2, desc= "Loading in representation model")
-    #print("Create LLM topic labels:", create_llm_topic_labels)
-    #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
-    #representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
-
-    # Reduce outliers if required, then update representation
-    progress(0.2, desc= "Reducing outliers")
-    print("Reducing outliers.")
-    # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
-    assigned_topics = topic_model.reduce_outliers(docs, assigned_topics, strategy="embeddings")
-    # Then, update the topics to the ones that considered the new data
-
-    print("Finished reducing outliers.")
-
-    progress(0.7, desc= "Replacing topic names with LLMs if necessary")
-    #print("Create LLM topic labels:", "No")
-    #vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
-    #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
-    #topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
-
-    topic_dets = topic_model.get_topic_info()
-
-    # Replace original labels with LLM labels
-    if "LLM" in topic_model.get_topic_info().columns:
-        llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
-        topic_model.set_topic_labels(llm_labels)
-    else:
-        topic_model.set_topic_labels(list(topic_dets["Name"]))
-
-    # Outputs
-    progress(0.9, desc= "Saving to file")
-    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
-
-    all_toc = time.perf_counter()
-    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
-    print(time_out)
-
-    return output_text, output_list, topic_model
-
-def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, save_topic_model, progress=gr.Progress(track_tqdm=True)):
-    #from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
-    from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
-
-    output_list = []
-
-    all_tic = time.perf_counter()
-
-    vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
-
-    assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
-
-    topic_dets = topic_model.get_topic_info()
-
-    progress(0.1, desc= "Loading LLM model")
-    print("Create LLM topic labels:", "Yes")
-    representation_model = create_representation_model("Yes", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
-
-    topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
-
-    # Replace original labels with LLM labels
-    if "LLM" in topic_model.get_topic_info().columns:
-        llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
-        topic_model.set_topic_labels(llm_labels)
-
-        label_list_file_name = data_file_name_no_ext + '_llm_topic_list_' + today_rev + '.csv'
-
-        llm_labels_df = pd.DataFrame(data={"Label":llm_labels})
-        llm_labels_df.to_csv(label_list_file_name, index=None)
-        #with open(label_list_file_name, 'w') as file:
-        #    file.write(f"Label\n")
-        #    for item in llm_labels:
-        #        file.write(f"{item}\n")
-        output_list.append(label_list_file_name)
-    else:
-        topic_model.set_topic_labels(list(topic_dets["Name"]))
-
-    # Outputs
-    progress(0.8, desc= "Saving outputs")
-    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
-
-    all_toc = time.perf_counter()
-    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
-    print(time_out)
-
-    return output_text, output_list, topic_model
-
-def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, legend_label, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
-
-    progress(0, desc= "Preparing data for visualisation")
-
-    output_list = []
-    vis_tic = time.perf_counter()
-
-    from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, visualize_barchart_custom
-
-    if not visualisation_type_radio:
-        return "Please choose a visualisation type above.", output_list, None, None
-
-    # Get topic labels
-    if in_label:
-        in_label_list_first = in_label[0]
-    else:
-        return "Label column not found. Please enter this above.", output_list, None, None
-
-    # Get docs
-    if in_colnames:
-        in_colnames_list_first = in_colnames[0]
-    else:
-        return "Label column not found. Please enter this on the data load tab.", output_list, None, None
-
-    docs = list(data[in_colnames_list_first].str.lower())
-
-    # Make sure format of input series is good
-    data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
-    label_list = list(data[in_label_list_first])
-
-    topic_dets = topic_model.get_topic_info()
-
-    # Replace original labels with another representation if specified
-    if legend_label:
-        topic_dets = topic_model.get_topics(full=True)
-        if legend_label in topic_dets:
-            labels = [topic_dets[legend_label].values()]
-            labels = [str(v) for v in labels]
-            topic_model.set_topic_labels(labels)
-
-    # Pre-reduce embeddings for visualisation purposes
-    if low_resource_mode == "No":
-        reduced_embeddings = UMAP(n_neighbors=15, n_components=2, min_dist=0.0, metric='cosine', random_state=random_seed).fit_transform(embeddings_out)
-    else:
-        reduced_embeddings = TruncatedSVD(2, random_state=random_seed).fit_transform(embeddings_out)
-
-    progress(0.5, desc= "Creating visualisation (this can take a while)")
-    # Visualise the topics:
-
-    print("Creating visualisation")
-
-    # "Topic document graph", "Hierarchical view"
-
-    if visualisation_type_radio == "Topic document graph":
-        topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop, width= 1200, height = 750)
-
-        topics_vis_name = data_file_name_no_ext + '_' + 'vis_topic_docs_' + today_rev + '.html'
-        topics_vis.write_html(topics_vis_name)
-        output_list.append(topics_vis_name)
-
-        topics_vis_2 = visualize_barchart_custom(topic_model, top_n_topics = 12, custom_labels=True, width= 300, height = 250)
-
-        topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_barchart_' + today_rev + '.html'
-        topics_vis_2.write_html(topics_vis_2_name)
-        output_list.append(topics_vis_2_name)
-
-    elif visualisation_type_radio == "Hierarchical view":
-
-        # Check that original topics are retained
-        #new_topic_dets = topic_model.get_topic_info()
-        #new_topic_dets.to_csv("new_topic_dets.csv")
-
-        #from funcs.bertopic_hierarchical_topics_mod import hierarchical_topics_mod
-
-        hierarchical_topics = topic_model.hierarchical_topics(docs)
-
-        # Save new hierarchical topic model to file
-        hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_' + today_rev + '.csv'
-        hierarchical_topics.to_csv(hierarchical_topics_name)
-        output_list.append(hierarchical_topics_name)
-
-        #hierarchical_topics = hierarchical_topics_mod(topic_model, docs)
-        topics_vis = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
-        #topics_vis = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
-        topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
-
-        topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
-        topics_vis.write_html(topics_vis_name)
-        output_list.append(topics_vis_name)
-
-        topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_hierarchy_' + today_rev + '.html'
-        topics_vis_2.write_html(topics_vis_2_name)
-        output_list.append(topics_vis_2_name)
-
-    all_toc = time.perf_counter()
-    time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
-    print(time_out)
-
-    return time_out, output_list, topics_vis, topics_vis_2
-
-def save_as_pytorch_model(topic_model, data_file_name_no_ext , progress=gr.Progress()):
-
-    if not topic_model:
-        return "No Pytorch model found.", None
-
-    progress(0, desc= "Saving topic model in Pytorch format")
-
-    output_list = []
-
-
-    topic_model_save_name_folder = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev# + ".safetensors"
-    topic_model_save_name_zip = topic_model_save_name_folder + ".zip"
-
-    # Clear folder before replacing files
-    delete_files_in_folder(topic_model_save_name_folder)
-
-    topic_model.save(topic_model_save_name_folder, serialization='pytorch', save_embedding_model=True, save_ctfidf=False)
-
-    # Zip file example
-
-    zip_folder(topic_model_save_name_folder, topic_model_save_name_zip)
-    output_list.append(topic_model_save_name_zip)
-
-    return "Model saved in Pytorch format.", output_list
 
 # Gradio app
 
@@ -524,6 +21,7 @@ with block:
     docs_state = gr.State()
    data_file_name_no_ext_state = gr.State()
    label_list_state = gr.State(pd.DataFrame())
+   vectoriser_state = gr.State(CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1, max_df=0.95))

    gr.Markdown(
    """
@@ -539,6 +37,13 @@ with block:
        with gr.Row():
            in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")

+       with gr.Accordion("Clean data", open = False):
+           with gr.Row():
+               clean_text = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Clean data - remove html, numbers with > 2 digits, emails, postcodes (UK).")
+               drop_duplicate_text = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Remove duplicate text, drop < 10 char strings. May make previous embedding files incompatible due to differing lengths.")
+               anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Anonymise data on file load. Personal details are redacted - not 100% effective!")
+           clean_btn = gr.Button("Clean data")
+
        with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
            candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
            zero_shot_similarity = gr.Slider(minimum = 0.5, maximum = 1, value = 0.65, step = 0.001, label = "Minimum similarity value for document to be assigned to zero-shot topic.")
@@ -548,20 +53,20 @@ with block:
            max_topics_slider = gr.Slider(minimum = 2, maximum = 500, value = 10, step = 1, label = "Maximum number of topics")

        with gr.Row():
-           topics_btn = gr.Button("Extract topics")
+           topics_btn = gr.Button("Extract topics", variant="primary")

        with gr.Row():
            output_single_text = gr.Textbox(label="Output topics")
            output_file = gr.File(label="Output file")

        with gr.Accordion("Post processing options.", open = True):
+           with gr.Row():
+               representation_type = gr.Dropdown(label = "Method for generating new topic labels", value="Default", choices=["Default", "MMR", "KeyBERT", "LLM"])
+               represent_llm_btn = gr.Button("Change topic labels")
            with gr.Row():
                reduce_outliers_btn = gr.Button("Reduce outliers")
-               represent_llm_btn = gr.Button("Generate topic labels with LLMs")
                save_pytorch_btn = gr.Button("Save model in Pytorch format")
-
-   #logs = gr.Textbox(label="Processing logs.")
-
+
    with gr.Tab("Visualise"):
        with gr.Row():
            visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
@@ -575,40 +80,39 @@ with block:
            out_plot_file = gr.File(label="Output plots to file", file_count="multiple")
            plot = gr.Plot(label="Visualise your topics here.")
            plot_2 = gr.Plot(label="Visualise your topics here.")
-

    with gr.Tab("Options"):
        with gr.Accordion("Data load and processing options", open = True):
            with gr.Row():
-               anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Anonymise data on file load. Names and other details are replaced with tags e.g. '<person>'.")
-               embedding_super_compress = gr.Dropdown(label = "Round embeddings to three dp for smaller files with less accuracy.", value="No", choices=["Yes", "No"])
                seed_number = gr.Number(label="Random seed to use for dimensionality reduction.", minimum=0, step=1, value=42, precision=0)
-               calc_probs = gr.Dropdown(label="Calculate all topic probabilities
+               calc_probs = gr.Dropdown(label="Calculate all topic probabilities", value="No", choices=["Yes", "No"])
            with gr.Row():
                low_resource_mode_opt = gr.Dropdown(label = "Use low resource embeddings and processing.", value="No", choices=["Yes", "No"])
+               embedding_super_compress = gr.Dropdown(label = "Round embeddings to three dp for smaller files with less accuracy.", value="No", choices=["Yes", "No"])
+           with gr.Row():
                return_intermediate_files = gr.Dropdown(label = "Return intermediate processing files from file preparation.", value="Yes", choices=["Yes", "No"])
-               save_topic_model = gr.Dropdown(label = "Save topic model to file.", value="
+               save_topic_model = gr.Dropdown(label = "Save topic model to BERTopic format pkl file.", value="No", choices=["Yes", "No"])

-   # Update column names dropdown when file uploaded
+   # Load in data. Update column names dropdown when file uploaded
    in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
    in_colnames.change(dummy_function, in_colnames, None)

-
+   # Clean data
+   clean_btn.click(fn=pre_clean, inputs=[data_state, in_colnames, data_file_name_no_ext_state, clean_text, drop_duplicate_text, anonymise_drop], outputs=[output_single_text, output_file, data_state, data_file_name_no_ext_state], api_name="clean")
+
+   # Extract topics
+   topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, zero_shot_similarity, seed_number, calc_probs, vectoriser_state], outputs=[output_single_text, output_file, embeddings_state, data_file_name_no_ext_state, topic_model_state, docs_state, vectoriser_state], api_name="topics")

+   # Reduce outliers
    reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")

-
+   # Re-represent topic labels
+   represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model, representation_type, vectoriser_state], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")

+   # Save in Pytorch format
    save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])

+   # Visualise topics
    plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")

-
-
-block.queue().launch(debug=True)#, server_name="0.0.0.0", ssl_verify=False, server_port=7860)
-
-
-
-
-
-
+block.queue().launch(debug=True)#, server_name="0.0.0.0", ssl_verify=False, server_port=7860)
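The processing functions now live in funcs/topic_core_funcs.py, and app.py only wires them to Gradio events. A minimal sketch of calling the relocated cleaning step directly, assuming the positional argument order shown in the clean_btn.click() wiring above (the variable names here are illustrative, not taken from the module):

# Hypothetical standalone call of the relocated cleaning step; argument order follows the
# inputs= list wired to clean_btn.click() above, names are illustrative only.
import pandas as pd
from funcs.topic_core_funcs import pre_clean

data = pd.DataFrame({"text": ["<p>Contact me at jane.doe@example.com</p>", "Flat 2, SW1A 1AA, call 0123456789"]})

message, output_files, cleaned_data, file_name = pre_clean(
    data,            # data_state
    ["text"],        # in_colnames
    "example_data",  # data_file_name_no_ext_state
    "Yes",           # clean_text
    "No",            # drop_duplicate_text
    "No",            # anonymise_drop
)
print(message)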
funcs/anonymiser.py
CHANGED
@@ -1,5 +1,6 @@
 from spacy.cli import download
 import spacy
+from funcs.presidio_analyzer_custom import analyze_dict
 spacy.prefer_gpu()
 
 def spacy_model_installed(model_name):
@@ -21,7 +22,7 @@ def spacy_model_installed(model_name):
 model_name = "en_core_web_sm"
 spacy_model_installed(model_name)
 
-spacy.load(model_name)
+#spacy.load(model_name)
 # Need to overwrite version of gradio present in Huggingface spaces as it doesn't have like buttons/avatars (Oct 2023)
 #os.system("pip uninstall -y gradio")
 #os.system("pip install gradio==3.50.0")
@@ -33,11 +34,10 @@ import base64
 import time
 
 import pandas as pd
-import gradio as gr
 
 from faker import Faker
 
-from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine
+from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, PatternRecognizer
 from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine
 from presidio_anonymizer.entities import OperatorConfig
 
@@ -159,10 +159,24 @@ def read_file(filename):
 
 def anonymise_script(df, chosen_col, anon_strat):
 
+    #print(df.shape)
+
+    #df_chosen_col_mask = (df[chosen_col].isnull()) | (df[chosen_col].str.strip() == "")
+    #print("Length of input series blank at start is: ", df_chosen_col_mask.value_counts())
+
     # DataFrame to dict
     df_dict = pd.DataFrame(data={chosen_col:df[chosen_col].astype(str)}).to_dict(orient="list")
 
+
+
     analyzer = AnalyzerEngine()
+
+    # Add titles to analyzer list
+    titles_recognizer = PatternRecognizer(supported_entity="TITLE",
+                                          deny_list=["Mr","Mrs","Miss", "Ms", "mr", "mrs", "miss", "ms"])
+
+    analyzer.registry.add_recognizer(titles_recognizer)
+
     batch_analyzer = BatchAnalyzerEngine(analyzer_engine=analyzer)
 
     anonymizer = AnonymizerEngine()
@@ -171,12 +185,13 @@ def anonymise_script(df, chosen_col, anon_strat):
 
     print("Identifying personal data")
     analyse_tic = time.perf_counter()
-    analyzer_results = batch_analyzer.analyze_dict(df_dict, language="en")
+    #analyzer_results = batch_analyzer.analyze_dict(df_dict, language="en")
+    analyzer_results = analyze_dict(batch_analyzer, df_dict, language="en")
     #print(analyzer_results)
     analyzer_results = list(analyzer_results)
 
     analyse_toc = time.perf_counter()
-    analyse_time_out = f"
+    analyse_time_out = f"Analysing the text took {analyse_toc - analyse_tic:0.1f} seconds."
     print(analyse_time_out)
 
     # Generate a 128-bit AES key. Then encode the key using base64 to get a string representation
@@ -206,15 +221,34 @@ def anonymise_script(df, chosen_col, anon_strat):
     if anon_strat == "encrypt": chosen_mask_config = people_encrypt_config
     elif anon_strat == "fake_first_name": chosen_mask_config = fake_first_name_config
 
-    # I think in general people will want to keep date / times
-    keep_date_config = eval('{"DATE_TIME": OperatorConfig("keep")}')
+    # I think in general people will want to keep date / times - NOT FOR TOPIC MODELLING
+    #keep_date_config = eval('{"DATE_TIME": OperatorConfig("keep")}')
 
-    combined_config = {**chosen_mask_config, **keep_date_config}
+    #combined_config = {**chosen_mask_config, **keep_date_config}
+    combined_config = {**chosen_mask_config}#, **keep_date_config}
     combined_config
 
+    print("Anonymising personal data")
     anonymizer_results = batch_anonymizer.anonymize_dict(analyzer_results, operators=combined_config)
 
-
+    #print(anonymizer_results)
+
+    scrubbed_df = pd.DataFrame(data={chosen_col:anonymizer_results[chosen_col]})
+
+    scrubbed_series = scrubbed_df[chosen_col]
+
+    #print(scrubbed_series[0:6])
+
+    #print("Length of output series is: ", len(scrubbed_series))
+    #print("Length of input series at end is: ", len(df[chosen_col]))
+
+
+    #scrubbed_values_mask = (scrubbed_series.isnull()) | (scrubbed_series.str.strip() == "")
+    #df_chosen_col_mask = (df[chosen_col].isnull()) | (df[chosen_col].str.strip() == "")
+
+    #print("Length of input series blank at end is: ", df_chosen_col_mask.value_counts())
+    #print("Length of output series blank is: ", scrubbed_values_mask.value_counts())
+
 
     # Create reporting message
     out_message = "Successfully anonymised"
@@ -222,7 +256,7 @@ def anonymise_script(df, chosen_col, anon_strat):
     if anon_strat == "encrypt":
         out_message = out_message + ". Your decryption key is " + key_string + "."
 
-    return
+    return scrubbed_series, out_message
 
 def do_anonymise(in_file, anon_strat, chosen_cols):
 
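A minimal sketch of the Presidio pattern that anonymise_script() now follows: register a deny-list recognizer for titles, analyse the column in batch, then replace whatever was found. The toy dataframe and the "REDACTED" replacement value are illustrative, and the sketch assumes the spaCy model installed above is available:

# Sketch of batch analysis + anonymisation with a custom TITLE deny-list recognizer.
import pandas as pd
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, PatternRecognizer
from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig

df = pd.DataFrame({"text": ["Mr John Smith emailed jane@example.com from London."]})
df_dict = df.astype(str).to_dict(orient="list")

analyzer = AnalyzerEngine()
analyzer.registry.add_recognizer(
    PatternRecognizer(supported_entity="TITLE", deny_list=["Mr", "Mrs", "Miss", "Ms"]))
batch_analyzer = BatchAnalyzerEngine(analyzer_engine=analyzer)
batch_anonymizer = BatchAnonymizerEngine(anonymizer_engine=AnonymizerEngine())

analyzer_results = list(batch_analyzer.analyze_dict(df_dict, language="en"))
anonymized = batch_anonymizer.anonymize_dict(
    analyzer_results,
    operators={"DEFAULT": OperatorConfig("replace", {"new_value": "REDACTED"})})
print(anonymized["text"])  # titles, names, emails and locations replaced with "REDACTED"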
funcs/clean_funcs.py
ADDED
@@ -0,0 +1,107 @@
+import re
+import string
+import polars as pl
+import gradio as gr
+
+# Adding custom words to the stopwords
+custom_words = []
+my_stop_words = custom_words
+
+# #### Some of my cleaning functions
+email_start_pattern_regex = r'.*importance:|.*subject:'
+email_end_pattern_regex = r'kind regards.*|many thanks.*|sincerely.*'
+html_pattern_regex = r'<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0|&nbsp;'
+email_pattern_regex = r'\S*@\S*\s?'
+num_pattern_regex = r'[0-9]+'
+nums_three_more_regex = r'\b[0-9]{3,}\b|\b[0-9]+\s[0-9]+\b'
+postcode_pattern_regex = r'(\b(?:[A-Z][A-HJ-Y]?[0-9][0-9A-Z]? ?[0-9][A-Z]{2})|((GIR ?0A{2})\b$)|(?:[A-Z][A-HJ-Y]?[0-9][0-9A-Z]? ?[0-9]{1}?)$)|(\b(?:[A-Z][A-HJ-Y]?[0-9][0-9A-Z]?)\b$)'
+warning_pattern_regex = r'caution: this email originated from outside of the organization. do not click links or open attachments unless you recognize the sender and know the content is safe.'
+nbsp_pattern_regex = r'&nbsp;'
+
+# Pre-compiling the regular expressions for efficiency (not actually used)
+# email_start_pattern = re.compile(email_start_pattern_regex)
+# email_end_pattern = re.compile(email_end_pattern_regex)
+# html_pattern = re.compile(html_pattern_regex)
+# email_pattern = re.compile(email_end_pattern_regex)
+# num_pattern = re.compile(num_pattern_regex)
+# nums_three_more_regex_pattern = re.compile(nums_three_more_regex)
+# postcode_pattern = re.compile(postcode_pattern_regex)
+# warning_pattern = re.compile(warning_pattern_regex)
+# nbsp_pattern = re.compile(nbsp_pattern_regex)
+
+def initial_clean(texts, progress=gr.Progress()):
+    texts = pl.Series(texts)
+    text = texts.str.replace_all(html_pattern_regex, '')
+    text = text.str.replace_all(email_pattern_regex, '')
+    text = text.str.replace_all(nums_three_more_regex, '')
+    text = text.str.replace_all(postcode_pattern_regex, '')
+
+    text = text.to_list()
+
+    return text
+
+def remove_hyphens(text_text):
+    return re.sub(r'(\w+)-(\w+)-?(\w)?', r'\1 \2 \3', text_text)
+
+
+def remove_characters_after_tokenization(tokens):
+    pattern = re.compile('[{}]'.format(re.escape(string.punctuation)))
+    filtered_tokens = filter(None, [pattern.sub('', token) for token in tokens])
+    return filtered_tokens
+
+def convert_to_lowercase(tokens):
+    return [token.lower() for token in tokens if token.isalpha()]
+
+def remove_short_tokens(tokens):
+    return [token for token in tokens if len(token) > 3]
+
+
+def remove_dups_text(data_samples_ready, data_samples_clean, data_samples):
+    # Identify duplicates in the data: https://stackoverflow.com/questions/44191465/efficiently-identify-duplicates-in-large-list-500-000
+    # Only identifies the second duplicate
+
+    seen = set()
+    dups = []
+
+    for i, doi in enumerate(data_samples_ready):
+        if doi not in seen:
+            seen.add(doi)
+        else:
+            dups.append(i)
+    #data_samples_ready[dupes[0:]]
+
+    # To see a specific duplicated value you know the position of
+    #matching = [s for s in data_samples_ready if data_samples_ready[83] in s]
+    #matching
+
+    # Remove duplicates only (keep first instance)
+    #data_samples_ready = list( dict.fromkeys(data_samples_ready) ) # This way would keep one version of the duplicates
+
+    ### Remove all duplicates including original instance
+
+    # Identify ALL duplicates including initial values
+    # https://stackoverflow.com/questions/11236006/identify-duplicate-values-in-a-list-in-python
+
+    from collections import defaultdict
+    D = defaultdict(list)
+    for i,item in enumerate(data_samples_ready):
+        D[item].append(i)
+    D = {k:v for k,v in D.items() if len(v)>1}
+
+    # https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
+    L = list(D.values())
+    flat_list_dups = [item for sublist in L for item in sublist]
+
+    # https://stackoverflow.com/questions/11303225/how-to-remove-multiple-indexes-from-a-list-at-the-same-time
+    for index in sorted(flat_list_dups, reverse=True):
+        del data_samples_ready[index]
+        del data_samples_clean[index]
+        del data_samples[index]
+
+    # Remove blanks
+    data_samples_ready = [i for i in data_samples_ready if i]
+    data_samples_clean = [i for i in data_samples_clean if i]
+    data_samples = [i for i in data_samples if i]
+
+    return data_samples_ready, data_samples_clean, flat_list_dups, data_samples
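A quick, hypothetical check of the new initial_clean() helper (it only assumes polars and gradio are installed, as the imports above require):

from funcs.clean_funcs import initial_clean

texts = ["<p>Please reply to jane.doe@example.com</p>", "Invoice 123456 sent to SW1A 1AA"]
# The regexes above aim to strip HTML tags, email addresses, 3+ digit numbers and trailing UK postcodes.
print(initial_clean(texts))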
funcs/helper_functions.py
CHANGED
@@ -1,3 +1,4 @@
+import sys
 import os
 import zipfile
 import re
@@ -12,6 +13,30 @@ from datetime import datetime
 today = datetime.now().strftime("%d%m%Y")
 today_rev = datetime.now().strftime("%Y%m%d")
 
+# Log terminal output: https://github.com/gradio-app/gradio/issues/2362
+class Logger:
+    def __init__(self, filename):
+        self.terminal = sys.stdout
+        self.log = open(filename, "w")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()
+        self.log.flush()
+
+    def isatty(self):
+        return False
+
+sys.stdout = Logger("output.log")
+
+def read_logs():
+    sys.stdout.flush()
+    with open("output.log", "r") as f:
+        return f.read()
+
 
 def detect_file_type(filename):
     """Detect the file type based on its extension."""
@@ -189,7 +214,10 @@ def save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, sa
     doc_dets.to_csv(doc_det_output_name)
     output_list.append(doc_det_output_name)
 
-
+    if "CustomName" in topic_dets.columns:
+        topics_text_out_str = str(topic_dets["CustomName"])
+    else:
+        topics_text_out_str = str(topic_dets["Name"])
     output_text = "Topics: " + topics_text_out_str
 
     # Save topic model to file
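One way the Logger/read_logs pair can be surfaced in the interface (a sketch, not wired up in this commit; it assumes a Gradio version in which event listeners accept the every= polling argument):

# Poll the captured stdout log into a Textbox every couple of seconds.
import gradio as gr
from funcs.helper_functions import read_logs

with gr.Blocks() as demo:
    logs = gr.Textbox(label="Processing logs")
    demo.load(read_logs, None, logs, every=2)

demo.queue().launch()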
funcs/presidio_analyzer_custom.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
from typing import List, Iterable, Dict, Union, Any, Optional, Iterator, Tuple
from tqdm import tqdm

from presidio_analyzer import DictAnalyzerResult, RecognizerResult, AnalyzerEngine
from presidio_analyzer.nlp_engine import NlpArtifacts

def analyze_iterator_custom(
    self,
    texts: Iterable[Union[str, bool, float, int]],
    language: str,
    list_length:int,
    progress=gr.Progress(),
    **kwargs,
) -> List[List[RecognizerResult]]:
    """
    Analyze an iterable of strings.

    :param texts: A list containing strings to be analyzed.
    :param language: Input language
    :param list_length: Length of the input list.
    :param kwargs: Additional parameters for the `AnalyzerEngine.analyze` method.
    """

    # validate types
    texts = self._validate_types(texts)

    # Process the texts as batch for improved performance
    nlp_artifacts_batch: Iterator[
        Tuple[str, NlpArtifacts]
    ] = self.analyzer_engine.nlp_engine.process_batch(
        texts=texts, language=language
    )

    list_results = []
    for text, nlp_artifacts in progress.tqdm(nlp_artifacts_batch, total = list_length, desc = "Analysing text for personal information", unit = "rows"):
        results = self.analyzer_engine.analyze(
            text=str(text), nlp_artifacts=nlp_artifacts, language=language, **kwargs
        )

        list_results.append(results)

    return list_results

def analyze_dict(
    self,
    input_dict: Dict[str, Union[Any, Iterable[Any]]],
    language: str,
    keys_to_skip: Optional[List[str]] = None,
    **kwargs,
) -> Iterator[DictAnalyzerResult]:
    """
    Analyze a dictionary of keys (strings) and values/iterable of values.

    Non-string values are returned as is.

    :param input_dict: The input dictionary for analysis
    :param language: Input language
    :param keys_to_skip: Keys to ignore during analysis
    :param kwargs: Additional keyword arguments
        for the `AnalyzerEngine.analyze` method.
        Use this to pass arguments to the analyze method,
        such as `ad_hoc_recognizers`, `context`, `return_decision_process`.
        See `AnalyzerEngine.analyze` for the full list.
    """

    context = []
    if "context" in kwargs:
        context = kwargs["context"]
        del kwargs["context"]

    if not keys_to_skip:
        keys_to_skip = []

    for key, value in input_dict.items():
        if not value or key in keys_to_skip:
            yield DictAnalyzerResult(key=key, value=value, recognizer_results=[])
            continue  # skip this key as requested

        # Add the key as an additional context
        specific_context = context[:]
        specific_context.append(key)

        if type(value) in (str, int, bool, float):
            results: List[RecognizerResult] = self.analyzer_engine.analyze(
                text=str(value), language=language, context=[key], **kwargs
            )
        elif isinstance(value, dict):
            # Recursively analyse nested dicts
            new_keys_to_skip = self._get_nested_keys_to_skip(key, keys_to_skip)
            results = self.analyze_dict(
                input_dict=value,
                language=language,
                context=specific_context,
                keys_to_skip=new_keys_to_skip,
                **kwargs,
            )
        elif isinstance(value, Iterable):
            # Iterate through the values (e.g. a column of texts) with a progress bar
            list_length = len(value)

            results: List[List[RecognizerResult]] = analyze_iterator_custom(self,
                texts=value,
                language=language,
                context=specific_context,
                list_length=list_length,
                **kwargs,
            )
        else:
            raise ValueError(f"type {type(value)} is unsupported.")

        yield DictAnalyzerResult(key=key, value=value, recognizer_results=results)
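
These two functions mirror Presidio's BatchAnalyzerEngine.analyze_iterator and analyze_dict, adding a Gradio progress bar over the batched NLP artifacts. A rough usage sketch, assuming they are bound to a BatchAnalyzerEngine instance (presumably how funcs/anonymiser.py consumes them); the example data and wiring are illustrative only, and inside the app the progress callbacks run under a Gradio event context:

import pandas as pd
from presidio_analyzer import BatchAnalyzerEngine
from funcs.presidio_analyzer_custom import analyze_dict

df = pd.DataFrame({"Response": ["My name is John Smith.", "Nothing personal in this one."]})

# Pass a BatchAnalyzerEngine as `self`, exactly as the functions above expect
batch_analyzer = BatchAnalyzerEngine()
results = list(analyze_dict(batch_analyzer, input_dict=df.to_dict(orient="list"), language="en"))

for res in results:
    print(res.key, res.recognizer_results)
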
funcs/representation_model.py
CHANGED
@@ -1,13 +1,11 @@
1 |   import os
2 | - #from ctransformers import AutoModelForCausalLM
3 | - #from transformers import AutoTokenizer, pipeline
4 |   from bertopic.representation import LlamaCPP
5 |   from llama_cpp import Llama
6 |   from pydantic import BaseModel
7 |   import torch.cuda
8 | - from huggingface_hub import hf_hub_download
9 |
10 | - from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance,
11 |   from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
12 |
13 |   random_seed = 42
@@ -16,8 +14,6 @@ chosen_prompt = open_hermes_prompt # stablelm_prompt
16 |   chosen_start_tag = open_hermes_start # stablelm_start
17 |
18 |
19 | -
20 | -
21 |   # Currently set n_gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
22 |   if torch.cuda.is_available():
23 |       torch_device = "gpu"
@@ -46,7 +42,7 @@ reset: bool = True
46 |       stream: bool = False
47 |       n_threads: int = n_threads
48 |       n_batch:int = 256
49 | -     n_ctx:int = 4096
50 |       sample:bool = True
51 |       trust_remote_code:bool =True
52 |
@@ -90,7 +86,9 @@ llm_config = LLamacppInitConfigGpu(last_n_tokens_size=last_n_tokens_size,
90 |   # KeyBERT
91 |   keybert = KeyBERTInspired(random_state=random_seed)
92 |   # MMR
93 | - mmr = MaximalMarginalRelevance(diversity=0.
94 |
95 |   # Find model file
96 |   def find_model_file(hf_model_name, hf_model_file, search_folder):
@@ -135,15 +133,12 @@ def find_model_file(hf_model_name, hf_model_file, search_folder):
135 |       return found_file
136 |
137 |
138 | - def create_representation_model(
139 |
140 | -     if
141 |           # Use llama.cpp to load in model
142 |
143 | -         # This was for testing on systems without a HF_HOME env variable
144 | -         #os.unsetenv("HF_HOME")
145 | -
146 | -         #if "HF_HOME" in os.environ:
147 |           # del os.environ["HF_HOME"]
148 |
149 |           # Check for HF_HOME environment variable and supply a default value if it's not found (typical location for huggingface models)
@@ -161,22 +156,27 @@ def create_representation_model(create_llm_topic_labels, llm_config, hf_model_na
161 |
162 |           found_file = find_model_file(hf_model_name, hf_model_file, hf_home_value)
163 |
164 | -         llm = Llama(model_path=found_file, stop=chosen_start_tag, n_gpu_layers=llm_config.n_gpu_layers, n_ctx=llm_config.n_ctx) #**llm_config.model_dump())#
165 |           #print(llm.n_gpu_layers)
166 |           llm_model = LlamaCPP(llm, prompt=chosen_prompt)#, **gen_config.model_dump())
167 |
168 |           # All representation models
169 |           representation_model = {
170 | -             "KeyBERT": keybert,
171 |               "LLM": llm_model
172 |           }
173 |
174 | -     elif
175 | -
176 | -
177 | -
178 | -
179 | -
180 |
181 |   # Deprecated example using CTransformers. This package is not really used anymore
182 |   #model = AutoModelForCausalLM.from_pretrained('NousResearch/Nous-Capybara-7B-V1.9-GGUF', model_type='mistral', model_file='Capybara-7B-V1.9-Q5_K_M.gguf', hf=True, **vars(llm_config))

1 |   import os
2 |   from bertopic.representation import LlamaCPP
3 |   from llama_cpp import Llama
4 |   from pydantic import BaseModel
5 |   import torch.cuda
6 | + from huggingface_hub import hf_hub_download
7 |
8 | + from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, BaseRepresentation
9 |   from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
10 |
11 |   random_seed = 42
14 |   chosen_start_tag = open_hermes_start # stablelm_start
15 |
16 |
17 |   # Currently set n_gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
18 |   if torch.cuda.is_available():
19 |       torch_device = "gpu"
42 |       stream: bool = False
43 |       n_threads: int = n_threads
44 |       n_batch:int = 256
45 | +     n_ctx:int = 8192 #4096. # Set to 8192 just to avoid any exceeded context window issues
46 |       sample:bool = True
47 |       trust_remote_code:bool =True
48 |
86 |   # KeyBERT
87 |   keybert = KeyBERTInspired(random_state=random_seed)
88 |   # MMR
89 | + mmr = MaximalMarginalRelevance(diversity=0.5)
90 | +
91 | + base_rep = BaseRepresentation()
92 |
93 |   # Find model file
94 |   def find_model_file(hf_model_name, hf_model_file, search_folder):
133 |       return found_file
134 |
135 |
136 | + def create_representation_model(representation_type, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode):
137 |
138 | +     if representation_type == "LLM":
139 | +         print("Generating LLM representation")
140 |           # Use llama.cpp to load in model
141 |
142 |           # del os.environ["HF_HOME"]
143 |
144 |           # Check for HF_HOME environment variable and supply a default value if it's not found (typical location for huggingface models)
156 |
157 |           found_file = find_model_file(hf_model_name, hf_model_file, hf_home_value)
158 |
159 | +         llm = Llama(model_path=found_file, stop=chosen_start_tag, n_gpu_layers=llm_config.n_gpu_layers, n_ctx=llm_config.n_ctx, rope_freq_scale=0.5) #**llm_config.model_dump())#
160 |           #print(llm.n_gpu_layers)
161 |           llm_model = LlamaCPP(llm, prompt=chosen_prompt)#, **gen_config.model_dump())
162 |
163 |           # All representation models
164 |           representation_model = {
165 |               "LLM": llm_model
166 |           }
167 |
168 | +     elif representation_type == "KeyBERT":
169 | +         print("Generating KeyBERT representation")
170 | +         #representation_model = {"mmr": mmr}
171 | +         representation_model = {"KeyBERT": keybert}
172 | +
173 | +     elif representation_type == "MMR":
174 | +         print("Generating MMR representation")
175 | +         representation_model = {"MMR": mmr}
176 | +
177 | +     else:
178 | +         print("Generating default representation type")
179 | +         representation_model = {"Default":base_rep}
180 |
181 |   # Deprecated example using CTransformers. This package is not really used anymore
182 |   #model = AutoModelForCausalLM.from_pretrained('NousResearch/Nous-Capybara-7B-V1.9-GGUF', model_type='mistral', model_file='Capybara-7B-V1.9-Q5_K_M.gguf', hf=True, **vars(llm_config))
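
create_representation_model() now takes a representation_type switch ("LLM", "KeyBERT", "MMR", or anything else for the default BaseRepresentation) and returns a dict that can be handed to BERTopic. A hedged sketch of a non-LLM call (the two model-file arguments are placeholders here, since they are only read on the "LLM" path, and topic_model, docs and vectoriser_model are assumed to come from an earlier topic run):

from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag

representation_model = create_representation_model(
    representation_type="MMR",
    llm_config=llm_config,
    hf_model_name="placeholder-model-repo",   # placeholder, only used when representation_type == "LLM"
    hf_model_file="placeholder-model.gguf",   # placeholder
    chosen_start_tag=chosen_start_tag,
    low_resource_mode="Yes",
)

# Apply the new representation to an existing, already-fitted BERTopic model
topic_model.update_topics(docs, vectorizer_model=vectoriser_model, representation_model=representation_model)
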
funcs/topic_core_funcs.py
ADDED
@@ -0,0 +1,500 @@
# Dendrograms will not work with the latest version of scipy (1.12.0), so installing the version prior to be safe
#os.system("pip install scipy==1.11.4")

import gradio as gr
from datetime import datetime
import pandas as pd
import numpy as np
import time
from bertopic import BERTopic

from funcs.clean_funcs import initial_clean
from funcs.helper_functions import read_file, zip_folder, delete_files_in_folder, save_topic_outputs
from funcs.embeddings import make_or_load_embeddings

from sentence_transformers import SentenceTransformer
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
import funcs.anonymiser as anon
from umap import UMAP

from torch import cuda, backends, version

# Default seed, can be changed in number selection on options page
random_seed = 42

# Check for torch cuda
# If you want to disable cuda for testing purposes
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

print("Is CUDA enabled? ", cuda.is_available())
print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
if cuda.is_available():
    torch_device = "gpu"
    print("Cuda version installed is: ", version.cuda)
    low_resource_mode = "No"
    #os.system("nvidia-smi")
else:
    torch_device = "cpu"
    low_resource_mode = "Yes"

print("Device used is: ", torch_device)

today = datetime.now().strftime("%d%m%Y")
today_rev = datetime.now().strftime("%Y%m%d")

# Load embeddings
embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"

# LLM model used for representing topics
hf_model_name = 'second-state/stablelm-2-zephyr-1.6b-GGUF' #'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1.9-GGUF'
hf_model_file = 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf' # 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf'

def pre_clean(data, in_colnames, data_file_name_no_ext, clean_text, drop_duplicate_text, anonymise_drop, progress=gr.Progress(track_tqdm=True)):

    output_text = ""
    output_list = []

    progress(0, desc = "Cleaning data")

    if not in_colnames:
        error_message = "Please enter one column name to use for cleaning and finding topics."
        print(error_message)
        return error_message, None, data_file_name_no_ext, None, None

    all_tic = time.perf_counter()

    output_list = []
    #file_list = [string.name for string in in_files]

    in_colnames_list_first = in_colnames[0]

    if clean_text == "Yes":
        clean_tic = time.perf_counter()
        print("Starting data clean.")

        data_file_name_no_ext = data_file_name_no_ext + "_clean"

        data[in_colnames_list_first] = initial_clean(data[in_colnames_list_first])

        clean_toc = time.perf_counter()
        clean_time_out = f"Cleaning the text took {clean_toc - clean_tic:0.1f} seconds."
        print(clean_time_out)

    if drop_duplicate_text == "Yes":
        progress(0.3, desc= "Drop duplicates - remove short texts")

        data_file_name_no_ext = data_file_name_no_ext + "_dedup"

        #print("Removing duplicates and short entries from data")
        #print("Data shape before: ", data.shape)
        data[in_colnames_list_first] = data[in_colnames_list_first].str.strip()
        data = data[data[in_colnames_list_first].str.len() >= 10]
        data = data.drop_duplicates(subset = in_colnames_list_first).dropna(subset= in_colnames_list_first).reset_index()

        #print("Data shape after duplicate/null removal: ", data.shape)

    if anonymise_drop == "Yes":
        progress(0.6, desc= "Anonymising data")

        data_file_name_no_ext = data_file_name_no_ext + "_anon"

        anon_tic = time.perf_counter()

        data_anon_col, anonymisation_success = anon.anonymise_script(data, in_colnames_list_first, anon_strat="redact")

        data[in_colnames_list_first] = data_anon_col

        print(anonymisation_success)

        anon_toc = time.perf_counter()
        time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"

    out_data_name = data_file_name_no_ext + "_" + today_rev + ".csv"
    data.to_csv(out_data_name)
    output_list.append(out_data_name)

    all_toc = time.perf_counter()
    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
    print(time_out)

    output_text = "Data clean completed."

    return output_text, output_list, data, data_file_name_no_ext

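A small sketch of calling the cleaning step directly, outside the app (not part of this commit; the column name, file name and example rows are made up, and inside the app the Gradio progress callbacks run under an event context):

import pandas as pd
from funcs.topic_core_funcs import pre_clean

example = pd.DataFrame({"Response": [
    "My first example comment  ",
    "My first example comment",
    "A different, rather longer example comment",
]})

# Clean and deduplicate the "Response" column; skip anonymisation in this sketch
out_text, out_files, cleaned_data, new_file_name = pre_clean(
    data=example,
    in_colnames=["Response"],
    data_file_name_no_ext="example_data",
    clean_text="Yes",
    drop_duplicate_text="Yes",
    anonymise_drop="No",
)
print(new_file_name)  # e.g. "example_data_clean_dedup"
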
def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext, custom_labels_df, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, random_seed, calc_probs, vectoriser_state, progress=gr.Progress(track_tqdm=True)):

    all_tic = time.perf_counter()

    progress(0, desc= "Loading data")

    output_list = []
    file_list = [string.name for string in in_files]

    if calc_probs == "No":
        calc_probs = False

    elif calc_probs == "Yes":
        print("Calculating all probabilities.")
        calc_probs = True

    if not in_colnames:
        error_message = "Please enter one column name to use for cleaning and finding topics."
        print(error_message)
        return error_message, None, data_file_name_no_ext, embeddings_out, None, None

    in_colnames_list_first = in_colnames[0]

    docs = list(data[in_colnames_list_first])

    # Check if embeddings are being loaded in
    progress(0.2, desc= "Loading/creating embeddings")

    print("Low resource mode: ", low_resource_mode)

    if low_resource_mode == "No":
        print("Using high resource BGE transformer model")

        embedding_model = SentenceTransformer(embeddings_name)

        # UMAP model uses Bertopic defaults
        umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', low_memory=False, random_state=random_seed)

    elif low_resource_mode == "Yes":
        print("Choosing low resource TF-IDF model.")

        embedding_model_pipe = make_pipeline(
            TfidfVectorizer(),
            TruncatedSVD(100)
        )
        embedding_model = embedding_model_pipe

        umap_model = TruncatedSVD(n_components=5, random_state=random_seed)

    embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)

    # This is saved as a Gradio state object
    vectoriser_model = vectoriser_state

    progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")

    fail_error_message = "Topic model creation failed. If you have a small dataset, try reducing minimum documents per topic."

    if not candidate_topics:

        try:

            topic_model = BERTopic( embedding_model=embedding_model,
                                    vectorizer_model=vectoriser_model,
                                    umap_model=umap_model,
                                    min_topic_size = min_docs_slider,
                                    nr_topics = max_topics_slider,
                                    calculate_probabilities=calc_probs,
                                    verbose = True)

            assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)

        except:
            print(fail_error_message)

            return fail_error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model

    # Do this if you have pre-defined topics
    else:
        if low_resource_mode == "Yes":
            error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
            print(error_message)

            return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs

        zero_shot_topics = read_file(candidate_topics.name)
        zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())

        try:
            topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, # for Jina
                                    vectorizer_model=vectoriser_model,
                                    umap_model=umap_model,
                                    min_topic_size = min_docs_slider,
                                    nr_topics = max_topics_slider,
                                    zeroshot_topic_list = zero_shot_topics_lower,
                                    zeroshot_min_similarity = zero_shot_similarity, # 0.7
                                    calculate_probabilities=calc_probs,
                                    verbose = True)

            assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)

        except:
            print(fail_error_message)

            return fail_error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model

        # For some reason, zero-shot topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
        if isinstance(assigned_topics, np.ndarray):
            assigned_topics = assigned_topics.tolist()

        # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot, we are merging two models that have exactly the same set of documents, the vocabulary should be the same, and so recreating the c_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700

        doc_dets = topic_model.get_document_info(docs)

        documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})

        # Assign CountVectorizer to merged model
        topic_model.vectorizer_model = vectoriser_model

        # Re-calculate c-TF-IDF
        c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
        topic_model.c_tf_idf_ = c_tf_idf

    if not assigned_topics:
        # Handle the empty array case
        return "No topics found.", output_list, embeddings_out, data_file_name_no_ext, topic_model, docs

    else:
        print("Topic model created.")

    # Replace current topic labels if new ones loaded in
    if not custom_labels_df.empty:
        #custom_label_list = list(custom_labels_df.iloc[:,0])
        custom_label_list = [label.replace("\n", "") for label in custom_labels_df.iloc[:,0]]

        topic_model.set_topic_labels(custom_label_list)

        print("Custom topics: ", topic_model.custom_labels_)

    # Outputs
    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)

    # If you want to save your embedding files
    if return_intermediate_files == "Yes":
        print("Saving embeddings to file")
        if low_resource_mode == "Yes":
            embeddings_file_name = data_file_name_no_ext + '_' + 'tfidf_embeddings.npz'
        else:
            if embeddings_super_compress == "No":
                embeddings_file_name = data_file_name_no_ext + '_' + 'bge_embeddings.npz'
            else:
                embeddings_file_name = data_file_name_no_ext + '_' + 'bge_embeddings_compress.npz'

        np.savez_compressed(embeddings_file_name, embeddings_out)

        output_list.append(embeddings_file_name)

    all_toc = time.perf_counter()
    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
    print(time_out)

    return output_text, output_list, embeddings_out, data_file_name_no_ext, topic_model, docs, vectoriser_model

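The vectoriser_state argument above is expected to hold a scikit-learn CountVectorizer kept in a Gradio State, so the same vocabulary settings can be reused by the zero-shot c-TF-IDF recreation and by later re-representation steps. A minimal sketch (the parameter values are illustrative, not necessarily the app's defaults):

import gradio as gr
from sklearn.feature_extraction.text import CountVectorizer

# CountVectorizer settings shared across the topic modelling steps
vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=2)
vectoriser_state = gr.State(vectoriser_model)
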
def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):

    progress(0, desc= "Preparing data")

    output_list = []

    all_tic = time.perf_counter()

    assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)

    if isinstance(assigned_topics, np.ndarray):
        assigned_topics = assigned_topics.tolist()

    # Reduce outliers if required, then update representation
    progress(0.2, desc= "Reducing outliers")
    print("Reducing outliers.")
    # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
    assigned_topics = topic_model.reduce_outliers(docs, assigned_topics, strategy="embeddings")
    # Then, update the topics to the ones that considered the new data

    print("Finished reducing outliers.")

    progress(0.7, desc= "Replacing topic names with LLMs if necessary")

    topic_dets = topic_model.get_topic_info()

    # Replace original labels with LLM labels
    if "LLM" in topic_model.get_topic_info().columns:
        llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
        topic_model.set_topic_labels(llm_labels)
    else:
        topic_model.set_topic_labels(list(topic_dets["Name"]))

    # Outputs
    progress(0.9, desc= "Saving to file")
    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)

    all_toc = time.perf_counter()
    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
    print(time_out)

    return output_text, output_list, topic_model

def represent_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode, save_topic_model, representation_type, vectoriser_model, progress=gr.Progress(track_tqdm=True)):
    from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag

    output_list = []

    all_tic = time.perf_counter()

    progress(0.1, desc= "Loading model and creating new representation")

    representation_model = create_representation_model(representation_type, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)

    progress(0.6, desc= "Updating existing topics")
    topic_model.update_topics(docs, vectorizer_model=vectoriser_model, representation_model=representation_model)

    topic_dets = topic_model.get_topic_info()

    # Replace original labels with LLM labels
    if representation_type == "LLM":
        llm_labels = [label[0].split("\n")[0] for label in topic_dets["LLM"]]
        topic_model.set_topic_labels(llm_labels)

        label_list_file_name = data_file_name_no_ext + '_llm_topic_list_' + today_rev + '.csv'

        llm_labels_df = pd.DataFrame(data={"Label":llm_labels})
        llm_labels_df.to_csv(label_list_file_name, index=None)

        output_list.append(label_list_file_name)
    else:
        new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ", aspect = representation_type)

        topic_model.set_topic_labels(new_topic_labels)#list(topic_dets[representation_type]))
        #topic_model.set_topic_labels(list(topic_dets["Name"]))

    # Outputs
    progress(0.8, desc= "Saving outputs")
    output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)

    all_toc = time.perf_counter()
    time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
    print(time_out)

    return output_text, output_list, topic_model

def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, legend_label, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress(track_tqdm=True)):

    progress(0, desc= "Preparing data for visualisation")

    output_list = []
    vis_tic = time.perf_counter()

    from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, visualize_barchart_custom

    if not visualisation_type_radio:
        return "Please choose a visualisation type above.", output_list, None, None

    # Get topic labels
    if in_label:
        in_label_list_first = in_label[0]
    else:
        return "Label column not found. Please enter this above.", output_list, None, None

    # Get docs
    if in_colnames:
        in_colnames_list_first = in_colnames[0]
    else:
        return "Label column not found. Please enter this on the data load tab.", output_list, None, None

    docs = list(data[in_colnames_list_first].str.lower())

    # Make sure format of input series is good
    data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
    label_list = list(data[in_label_list_first])

    topic_dets = topic_model.get_topic_info()

    # Replace original labels with another representation if specified
    if legend_label:
        topic_dets = topic_model.get_topics(full=True)
        if legend_label in topic_dets:
            labels = [topic_dets[legend_label].values()]
            labels = [str(v) for v in labels]
            topic_model.set_topic_labels(labels)

    # Pre-reduce embeddings for visualisation purposes
    if low_resource_mode == "No":
        reduced_embeddings = UMAP(n_neighbors=15, n_components=2, min_dist=0.0, metric='cosine', random_state=random_seed).fit_transform(embeddings_out)
    else:
        reduced_embeddings = TruncatedSVD(2, random_state=random_seed).fit_transform(embeddings_out)

    progress(0.5, desc= "Creating visualisation (this can take a while)")
    # Visualise the topics:

    print("Creating visualisation")

    # "Topic document graph", "Hierarchical view"

    if visualisation_type_radio == "Topic document graph":
        topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop, width= 1200, height = 750)

        topics_vis_name = data_file_name_no_ext + '_' + 'vis_topic_docs_' + today_rev + '.html'
        topics_vis.write_html(topics_vis_name)
        output_list.append(topics_vis_name)

        topics_vis_2 = topic_model.visualize_heatmap(custom_labels=True, width= 1200, height = 1200)

        topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_heatmap_' + today_rev + '.html'
        topics_vis_2.write_html(topics_vis_2_name)
        output_list.append(topics_vis_2_name)

    elif visualisation_type_radio == "Hierarchical view":

        hierarchical_topics = topic_model.hierarchical_topics(docs)

        # Save new hierarchical topic model to file
        hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_' + today_rev + '.csv'
        hierarchical_topics.to_csv(hierarchical_topics_name)
        output_list.append(hierarchical_topics_name)

        try:
            topics_vis = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
            topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
        except:
            error_message = "Visualisation preparation failed. Perhaps you need more topics to create the full hierarchy (more than 10)?"
            return error_message, output_list, None, None

        topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
        topics_vis.write_html(topics_vis_name)
        output_list.append(topics_vis_name)

        topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_hierarchy_' + today_rev + '.html'
        topics_vis_2.write_html(topics_vis_2_name)
        output_list.append(topics_vis_2_name)

    all_toc = time.perf_counter()
    time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
    print(time_out)

    return time_out, output_list, topics_vis, topics_vis_2

def save_as_pytorch_model(topic_model, data_file_name_no_ext , progress=gr.Progress(track_tqdm=True)):

    if not topic_model:
        return "No Pytorch model found.", None

    progress(0, desc= "Saving topic model in Pytorch format")

    output_list = []

    topic_model_save_name_folder = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev# + ".safetensors"
    topic_model_save_name_zip = topic_model_save_name_folder + ".zip"

    # Clear folder before replacing files
    delete_files_in_folder(topic_model_save_name_folder)

    topic_model.save(topic_model_save_name_folder, serialization='pytorch', save_embedding_model=True, save_ctfidf=False)

    # Zip file example

    zip_folder(topic_model_save_name_folder, topic_model_save_name_zip)
    output_list.append(topic_model_save_name_zip)

    return "Model saved in Pytorch format.", output_list

requirements.txt
CHANGED
@@ -2,11 +2,14 @@ gradio==3.50.0
2 |   transformers==4.37.1
3 |   accelerate==0.26.1
4 |   torch==2.1.2
5 | - llama-cpp-python==0.2.
6 |   bertopic==0.16.0
7 |   spacy==3.7.2
8 |   pyarrow==14.0.2
9 |   Faker==22.2.0
10 |  presidio_analyzer==2.2.351
11 |  presidio_anonymizer==2.2.351
12 |  scipy==1.11.4

2 |   transformers==4.37.1
3 |   accelerate==0.26.1
4 |   torch==2.1.2
5 | + llama-cpp-python==0.2.36
6 |   bertopic==0.16.0
7 |   spacy==3.7.2
8 | + en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
9 |   pyarrow==14.0.2
10 | + openpyxl==3.1.2
11 |   Faker==22.2.0
12 |   presidio_analyzer==2.2.351
13 |   presidio_anonymizer==2.2.351
14 |   scipy==1.11.4
15 | + polars==0.20.6