AdhyaSuman committed
Commit dd63f62 · 2 parents: 1ad4706, 11c72a2

Merge master into main, resolved conflicts and updated LFS tracking

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +5 -0
  2. .huggingface.yaml +3 -0
  3. LICENSE +21 -0
  4. app/ui_updated.py +450 -0
  5. assets/Logo_light.png +3 -0
  6. backend/__init__.py +81 -0
  7. backend/datasets/_preprocess.py +447 -0
  8. backend/datasets/data/download.py +32 -0
  9. backend/datasets/data/file_utils.py +39 -0
  10. backend/datasets/dynamic_dataset.py +90 -0
  11. backend/datasets/preprocess.py +362 -0
  12. backend/datasets/utils/_utils.py +37 -0
  13. backend/datasets/utils/logger.py +29 -0
  14. backend/evaluation/CoherenceModel_ttc.py +862 -0
  15. backend/evaluation/eval.py +179 -0
  16. backend/inference/doc_retriever.py +219 -0
  17. backend/inference/indexing_utils.py +146 -0
  18. backend/inference/peak_detector.py +18 -0
  19. backend/inference/process_beta.py +33 -0
  20. backend/inference/word_selector.py +102 -0
  21. backend/llm/custom_gemini.py +28 -0
  22. backend/llm/custom_mistral.py +27 -0
  23. backend/llm/llm_router.py +73 -0
  24. backend/llm_utils/label_generator.py +72 -0
  25. backend/llm_utils/summarizer.py +192 -0
  26. backend/llm_utils/token_utils.py +167 -0
  27. backend/models/CFDTM/CFDTM.py +127 -0
  28. backend/models/CFDTM/ETC.py +62 -0
  29. backend/models/CFDTM/Encoder.py +40 -0
  30. backend/models/CFDTM/UWE.py +48 -0
  31. backend/models/CFDTM/__init__.py +0 -0
  32. backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc +0 -0
  33. backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc +0 -0
  34. backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc +0 -0
  35. backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc +0 -0
  36. backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc +0 -0
  37. backend/models/DBERTopic_trainer.py +99 -0
  38. backend/models/DETM.py +259 -0
  39. backend/models/DTM_trainer.py +148 -0
  40. backend/models/dynamic_trainer.py +177 -0
  41. data/ACL_Anthology/CFDTM/beta.npy +3 -0
  42. data/ACL_Anthology/DETM/beta.npy +3 -0
  43. data/ACL_Anthology/DTM/beta.npy +3 -0
  44. data/ACL_Anthology/DTM/topic_label_cache.json +3 -0
  45. data/ACL_Anthology/docs.jsonl +3 -0
  46. data/ACL_Anthology/inverted_index.json +3 -0
  47. data/ACL_Anthology/processed/lemma_to_forms.json +3 -0
  48. data/ACL_Anthology/processed/length_stats.json +3 -0
  49. data/ACL_Anthology/processed/time2id.txt +18 -0
  50. data/ACL_Anthology/processed/vocab.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/**/*.npy filter=lfs diff=lfs merge=lfs -text
+ data/**/*.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/**/*.json filter=lfs diff=lfs merge=lfs -text
+ assets/*.png filter=lfs diff=lfs merge=lfs -text
+ data/**/*.npz filter=lfs diff=lfs merge=lfs -text
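
Note: the five added patterns route the dataset arrays, JSON/JSONL files, and the logo PNG through Git LFS. A minimal sketch to preview which paths the new patterns would capture, run from the repo root (the patterns are copied from the diff above; everything else is an assumption):

    from pathlib import Path

    patterns = ["data/**/*.npy", "data/**/*.jsonl", "data/**/*.json",
                "assets/*.png", "data/**/*.npz"]
    for pat in patterns:
        for path in Path(".").glob(pat):
            print(f"{pat} -> {path}")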
.huggingface.yaml ADDED
@@ -0,0 +1,3 @@
+ # .huggingface.yaml
+ sdk: streamlit # or gradio
+ app_file: ./app/ui.py
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Suman Adhya
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app/ui_updated.py ADDED
@@ -0,0 +1,450 @@
+ import streamlit as st
+ import plotly.graph_objects as go
+ import plotly.colors as pc
+ import sys
+ import os
+ import base64
+ import streamlit.components.v1 as components
+ import html
+
+ # Absolute path to the repo root (assuming `ui.py` is in /app)
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+ sys.path.append(REPO_ROOT)
+ ASSETS_DIR = os.path.join(REPO_ROOT, 'assets')
+ DATA_DIR = os.path.join(REPO_ROOT, 'data')
+
+ # Import functions from the backend
+ from backend.inference.process_beta import (
+     load_beta_matrix,
+     get_top_words_over_time,
+     load_time_labels
+ )
+ from backend.inference.word_selector import get_interesting_words, get_word_trend
+ from backend.inference.indexing_utils import load_index
+ from backend.inference.doc_retriever import (
+     load_length_stats,
+     get_yearly_counts_for_word,
+     deduplicate_docs,
+     get_all_documents_for_word_year,
+     highlight_words,
+     extract_snippet
+ )
+ from backend.llm_utils.summarizer import summarize_multiword_docs, ask_multiturn_followup
+ from backend.llm_utils.label_generator import get_topic_labels
+ from backend.llm.llm_router import get_llm, list_supported_models
+ from backend.llm_utils.token_utils import estimate_k_max_from_word_stats
+
+ def get_base64_image(image_path):
+     with open(image_path, "rb") as img_file:
+         return base64.b64encode(img_file.read()).decode()
+
+ # --- Page Configuration ---
+ st.set_page_config(
+     page_title="DTECT",
+     page_icon="🔍",
+     layout="wide"
+ )
+
+ # Sidebar branding and repo link
+ st.sidebar.markdown(
+     """
+     <div style="text-align: center;">
+         <a href="https://github.com/dinb-ai/DTECT" target="_blank">
+             <img src="data:image/png;base64,{}" width="180" style="margin-bottom: 18px;">
+         </a>
+         <hr style="margin-bottom: 0;">
+     </div>
+     """.format(get_base64_image(os.path.join(ASSETS_DIR, 'Logo_light.png'))),
+     unsafe_allow_html=True
+ )
+
+ # 1. Sidebar: Model and Dataset Selection
+ st.sidebar.title("Configuration")
+
+ AVAILABLE_MODELS = ["DTM", "DETM", "CFDTM"]
+ ENV_VAR_MAP = {
+     "OpenAI": "OPENAI_API_KEY",
+     "Anthropic": "ANTHROPIC_API_KEY",
+     "Gemini": "GEMINI_API_KEY",
+     "Mistral": "MISTRAL_API_KEY"
+ }
+
+ def list_datasets(data_dir):
+     return sorted([
+         name for name in os.listdir(data_dir)
+         if os.path.isdir(os.path.join(data_dir, name))
+     ])
+
+ with st.sidebar.expander("Select Dataset & Topic Model", expanded=True):
+     datasets = list_datasets(DATA_DIR)
+     selected_dataset = st.selectbox("Dataset", datasets, help="Choose an available dataset.")
+     selected_model = st.selectbox("Model", AVAILABLE_MODELS, help="Select topic model architecture.")
+
+ # Resolve paths
+ dataset_path = os.path.join(DATA_DIR, selected_dataset)
+ model_path = os.path.join(dataset_path, selected_model)
+ docs_path = os.path.join(dataset_path, "docs.jsonl")
+ vocab_path = os.path.join(dataset_path, "processed/vocab.txt")
+ time2id_path = os.path.join(dataset_path, "processed/time2id.txt")
+ index_path = os.path.join(dataset_path, "inverted_index.json")
+ beta_path = os.path.join(model_path, "beta.npy")
+ label_cache_path = os.path.join(model_path, "topic_label_cache.json")
+ length_stats_path = os.path.join(dataset_path, "processed/length_stats.json")
+ lemma_map_path = os.path.join(dataset_path, "processed/lemma_to_forms.json")
+
+ with st.sidebar.expander("LLM Settings", expanded=True):
+     provider = st.selectbox("LLM Provider", options=list(ENV_VAR_MAP.keys()), help="Choose the LLM backend.")
+     available_models = list_supported_models(provider)
+     model = st.selectbox("LLM Model", options=available_models)
+     env_var = ENV_VAR_MAP[provider]
+     api_key = os.getenv(env_var)
+
+     if "llm_configured" not in st.session_state:
+         st.session_state.llm_configured = False
+
+     if api_key:
+         st.session_state.llm_configured = True
+     else:
+         st.session_state.llm_configured = False
+         with st.form(key="api_key_form"):
+             entered_key = st.text_input(f"Enter your {provider} API Key", type="password")
+             submitted = st.form_submit_button("Submit and Confirm")
+             if submitted:
+                 if entered_key:
+                     os.environ[env_var] = entered_key
+                     api_key = entered_key
+                     st.session_state.llm_configured = True
+                     st.rerun()
+                 else:
+                     st.warning("Please enter a key.")
+
+ if not st.session_state.llm_configured:
+     st.warning("Please configure your LLM settings in the sidebar.")
+     st.stop()
+
+ if api_key and not st.session_state.llm_configured:
+     st.session_state.llm_configured = True
+
+ if not api_key:
+     st.session_state.llm_configured = False
+
+ if not st.session_state.llm_configured:
+     st.warning("Please configure your LLM settings in the sidebar.")
+     st.stop()
+
+ # Initialize LLM with the provided key
+ llm = get_llm(provider=provider, model=model, api_key=api_key)
+
+ # 3. Load Data
+ @st.cache_resource
+ def load_resources(beta_path, vocab_path, docs_path, index_path, time2id_path, length_stats_path, lemma_map_path):
+     beta, vocab = load_beta_matrix(beta_path, vocab_path)
+     index, docs, lemma_to_forms = load_index(docs_file_path=docs_path, vocab=vocab, index_path=index_path, lemma_map_path=lemma_map_path)
+     time_labels = load_time_labels(time2id_path)
+     length_stats = load_length_stats(length_stats_path)
+     return beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats
+
+ # --- Main Title and Paper-aligned Intro ---
+ st.markdown("""# 🔍 DTECT: Dynamic Topic Explorer & Context Tracker""")
+
+ # --- Load resources ---
+ try:
+     beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats = load_resources(
+         beta_path,
+         vocab_path,
+         docs_path,
+         index_path,
+         time2id_path,
+         length_stats_path,
+         lemma_map_path
+     )
+ except FileNotFoundError as e:
+     st.error(f"Missing required file: {e}")
+     st.stop()
+ except Exception as e:
+     st.error(f"Failed to load data: {str(e)}")
+     st.stop()
+
+ timestamps = list(range(len(time_labels)))
+ num_topics = beta.shape[1]
+ # Estimate max_k based on document length stats and selected LLM
+ suggested_max_k = estimate_k_max_from_word_stats(length_stats.get("avg_len"), model_name=model, provider=provider)
+
+
+ # ==============================================================================
+ # 1. 🏷 TOPIC LABELING
+ # ==============================================================================
+ st.markdown("## 1️⃣ 🏷️ Topic Labeling")
+ st.info("Topics are automatically labeled using LLMs by analyzing their temporal word distributions.")
+
+ topic_labels = get_topic_labels(beta, vocab, time_labels, llm, label_cache_path)
+ topic_options = list(topic_labels.values())
+ selected_topic_label = st.selectbox("Select a Topic", topic_options, help="LLM-generated topic label")
+ label_to_topic = {v: k for k, v in topic_labels.items()}
+ selected_topic = label_to_topic[selected_topic_label]
+
+ # ==============================================================================
+ # 2. 💡 INFORMATIVE WORD DETECTION & 📊 TREND VISUALIZATION
+ # ==============================================================================
+ st.markdown("---")
+ st.markdown("## 2️⃣ 💡 Informative Word Detection & 📊 Trend Visualization")
+ st.info("Explore top/interesting words for each topic, and visualize their trends over time.")
+
+ top_n_words = st.slider("Number of Top Words per Topic", min_value=5, max_value=500, value=10)
+ top_words = get_top_words_over_time(
+     beta=beta,
+     vocab=vocab,
+     topic_id=selected_topic,
+     top_n=top_n_words
+ )
+
+ st.write(f"### Top {top_n_words} Words for Topic '{selected_topic_label}' (Ranked):")
+ scrollable_top_words = "<div style='max-height: 200px; overflow-y: auto; padding: 0 10px;'>"
+ words_per_col = (top_n_words + 3) // 4
+ columns = [top_words[i:i+words_per_col] for i in range(0, len(top_words), words_per_col)]
+ scrollable_top_words += "<div style='display: flex; gap: 20px;'>"
+ word_rank = 1
+ for col in columns:
+     scrollable_top_words += "<div style='flex: 1;'>"
+     for word in col:
+         scrollable_top_words += f"<div style='margin-bottom: 4px;'>{word_rank}. {word}</div>"
+         word_rank += 1
+     scrollable_top_words += "</div>"
+ scrollable_top_words += "</div></div>"
+ st.markdown(scrollable_top_words, unsafe_allow_html=True)
+
+ st.markdown("<div style='margin-top: 18px;'></div>", unsafe_allow_html=True)
+
+ if st.button("💡 Suggest Informative Words", key="suggest_topic_words"):
+     top_words = get_top_words_over_time(
+         beta=beta,
+         vocab=vocab,
+         topic_id=selected_topic,
+         top_n=top_n_words
+     )
+     interesting_words = get_interesting_words(beta, vocab, topic_id=selected_topic, restrict_to=top_words)
+     st.session_state.interesting_words = interesting_words
+     st.session_state.selected_words = interesting_words[:15]  # pre-fill multiselect
+     styled_words = " ".join([
+         f"<span style='background-color:#e0f7fa; color:#004d40; font-weight:500; padding:4px 8px; margin:4px; border-radius:8px; display:inline-block;'>{w}</span>"
+         for w in interesting_words
+     ])
+     st.markdown(
+         f"**Top Informative Words from Topic '{selected_topic_label}':**<br>{styled_words}",
+         unsafe_allow_html=True
+     )
+
+ st.markdown("#### 📈 Plot Word Trends Over Time")
+ all_word_options = vocab
+ interesting_words = st.session_state.get("interesting_words", [])
+
+ if "selected_words" not in st.session_state:
+     st.session_state.selected_words = interesting_words[:15]  # initial default
+
+ selected_words = st.multiselect(
+     "Select words to visualize trends",
+     options=all_word_options,
+     default=st.session_state.selected_words,
+     key="selected_words"
+ )
+ if selected_words:
+     fig = go.Figure()
+     color_cycle = pc.qualitative.Plotly
+     for i, word in enumerate(selected_words):
+         trend = get_word_trend(beta, vocab, word, topic_id=selected_topic)
+         color = color_cycle[i % len(color_cycle)]
+         fig.add_trace(go.Scatter(
+             x=time_labels,
+             y=trend,
+             name=word,
+             line=dict(color=color),
+             legendgroup=word,
+             showlegend=True
+         ))
+     fig.update_layout(title="", xaxis_title="Year", yaxis_title="Importance")
+     st.plotly_chart(fig, use_container_width=True)
+
+ # ==============================================================================
+ # 3. 🔍 DOCUMENT RETRIEVAL & 📃 SUMMARIZATION
+ # ==============================================================================
+ st.markdown("---")
+ st.markdown("## 3️⃣ 🔍 Document Retrieval & 📃 Summarization")
+ st.info("Retrieve and summarize documents matching selected words and years.")
+
+ if selected_words:
+     st.markdown("#### 📊 Document Frequency Over Time")
+     selected_words_for_counts = st.multiselect(
+         "Select word(s) to show document frequencies over time",
+         options=selected_words,
+         default=selected_words[:3],
+         key="word_counts_multiselect"
+     )
+
+     if selected_words_for_counts:
+         color_cycle = pc.qualitative.Set2
+         bar_fig = go.Figure()
+         for i, word in enumerate(selected_words_for_counts):
+             doc_years, doc_counts = get_yearly_counts_for_word(index=index, word=word)
+             bar_fig.add_trace(go.Bar(
+                 x=doc_years,
+                 y=doc_counts,
+                 name=word,
+                 marker_color=color_cycle[i % len(color_cycle)],
+                 opacity=0.85
+             ))
+         bar_fig.update_layout(
+             barmode="group",
+             title="Document Frequency Over Time",
+             xaxis_title="Year",
+             yaxis_title="Document Count",
+             xaxis=dict(
+                 tickmode='linear',
+                 dtick=1,
+                 tickformat='d'
+             ),
+             bargap=0.2
+         )
+         st.plotly_chart(bar_fig, use_container_width=True)
+
+         st.markdown("#### 📄 Inspect Documents for Word-Year Pairs")
+         # selected_year = st.slider("Select year", min_value=int(time_labels[0]), max_value=int(time_labels[-1]), key="inspect_year_slider")
+         selected_year = st.selectbox(
+             "Select year",
+             options=time_labels,  # Use the list of available time labels (years)
+             index=0,  # Default to the first year in the list
+             key="inspect_year_selectbox"
+         )
+         collected_docs_raw = []
+         for word in selected_words_for_counts:
+             docs_for_word_year = get_all_documents_for_word_year(
+                 index=index,
+                 docs_file_path=docs_path,
+                 word=word,
+                 year=selected_year
+             )
+             for doc in docs_for_word_year:
+                 doc["__word__"] = word
+             collected_docs_raw.extend(docs_for_word_year)
+
+         if collected_docs_raw:
+             st.session_state.collected_deduplicated_docs = deduplicate_docs(collected_docs_raw)
+             st.write(f"Found {len(collected_docs_raw)} matching documents, {len(st.session_state.collected_deduplicated_docs)} after deduplication.")
+
+             html_blocks = ""
+             for doc in st.session_state.collected_deduplicated_docs:
+                 word = doc["__word__"]
+                 full_text = html.escape(doc["text"])
+                 snippet_text = extract_snippet(doc["text"], word)
+                 highlighted_snippet = highlight_words(
+                     snippet_text,
+                     query_words=selected_words_for_counts,
+                     lemma_to_forms=lemma_to_forms
+                 )
+                 html_blocks += f"""
+                 <div style="margin-bottom: 14px; padding: 10px; background-color: #fffbe6; border: 1px solid #f0e6cc; border-radius: 6px;">
+                     <div style="color: #333;"><strong>Match:</strong> {word} | <strong>Doc ID:</strong> {doc['id']} | <strong>Timestamp:</strong> {doc['timestamp']}</div>
+                     <div style="margin-top: 4px; color: #444;"><em>Snippet:</em> {highlighted_snippet}</div>
+                     <details style="margin-top: 4px;">
+                         <summary style="cursor: pointer; color: #007acc;">Show full document</summary>
+                         <pre style="white-space: pre-wrap; color: #111; background-color: #fffef5; padding: 8px; border: 1px solid #f0e6cc; border-radius: 4px;">{full_text}</pre>
+                     </details>
+                 </div>
+                 """
+             min_height = 120
+             max_height = 700
+             per_doc_height = 130
+             dynamic_height = min_height + per_doc_height * max(len(st.session_state.collected_deduplicated_docs) - 1, 0)
+             container_height = min(dynamic_height, max_height)
+             scrollable_html = f"""
+             <div style="overflow-y: auto; padding: 10px;
+                         border: 1px solid #f0e6cc; border-radius: 6px;
+                         background-color: #fffbe6; color: #222;
+                         margin-bottom: 0;">
+                 {html_blocks}
+             </div>
+             """
+             components.html(scrollable_html, height=container_height, scrolling=True)
+         else:
+             st.warning("No documents found for the selected words and year.")
+
+ # ==============================================================================
+ # 4. 💬 CHAT ASSISTANT (Summary & Follow-up)
+ # ==============================================================================
+ st.markdown("---")
+ st.markdown("## 4️⃣ 💬 Chat Assistant")
+ st.info("Generate summaries from the inspected documents and ask follow-up questions.")
+
+ if "summary" not in st.session_state:
+     st.session_state.summary = None
+ if "context_for_followup" not in st.session_state:
+     st.session_state.context_for_followup = ""
+ if "followup_history" not in st.session_state:
+     st.session_state.followup_history = []
+
+ # MMR K selection
+ st.markdown(f"**Max documents for summarization (k):**")
+ st.markdown(f"The suggested maximum number of documents for summarization (k) based on the average document length and the selected LLM is **{suggested_max_k}**.")
+ mmr_k = st.slider(
+     "Select the maximum number of documents (k) for MMR (Maximum Marginal Relevance) selection for summarization.",
+     min_value=1,
+     max_value=20,  # Set a reasonable max for k, can be adjusted
+     value=min(suggested_max_k, 20),  # Use suggested_max_k as default, capped at 20
+     help="This value determines how many relevant and diverse documents will be selected for summarization."
+ )
+
+ if st.button("📃 Summarize These Documents"):
+     if st.session_state.get("collected_deduplicated_docs"):
+         st.session_state.summary = None
+         st.session_state.context_for_followup = ""
+         st.session_state.followup_history = []
+         with st.spinner("Selecting and summarizing documents..."):
+             summary, mmr_docs = summarize_multiword_docs(
+                 selected_words_for_counts,
+                 selected_year,
+                 st.session_state.collected_deduplicated_docs,
+                 llm,
+                 k=mmr_k
+             )
+             st.session_state.summary = summary
+             st.session_state.context_for_followup = "\n".join(
+                 f"Document {i+1}:\n{doc.page_content.strip()}" for i, doc in enumerate(mmr_docs)
+             )
+             st.session_state.followup_history.append(
+                 {"role": "user", "content": f"Please summarize the context of the words '{', '.join(selected_words_for_counts)}' in {selected_year} based on the provided documents."}
+             )
+             st.session_state.followup_history.append(
+                 {"role": "assistant", "content": st.session_state.summary}
+             )
+             st.success(f"✅ Summary generated from {len(mmr_docs)} MMR-selected documents.")
+     else:
+         st.warning("⚠️ No documents collected to summarize. Please inspect some documents first.")
+
+ if st.session_state.summary:
+     st.markdown(f"**Summary for words `{', '.join(selected_words_for_counts)}` in `{selected_year}`:**")
+     st.write(st.session_state.summary)
+
+     if st.checkbox("💬 Ask follow-up questions about this summary", key="enable_followup"):
+         with st.expander("View the documents used for this conversation"):
+             st.text_area("Context Documents", st.session_state.context_for_followup, height=200)
+         st.info("Ask a question based on the summary and the documents above.")
+         for msg in st.session_state.followup_history[2:]:
+             with st.chat_message(msg["role"], avatar="🧑" if msg["role"] == "user" else "🤖"):
+                 st.markdown(msg["content"])
+         if user_query := st.chat_input("Ask a follow-up question..."):
+             with st.chat_message("user", avatar="🧑"):
+                 st.markdown(user_query)
+             st.session_state.followup_history.append({"role": "user", "content": user_query})
+             with st.spinner("Thinking..."):
+                 followup_response = ask_multiturn_followup(
+                     history=st.session_state.followup_history,
+                     question=user_query,
+                     llm=llm,
+                     context_texts=st.session_state.context_for_followup
+                 )
+             st.session_state.followup_history.append({"role": "assistant", "content": followup_response})
+             if followup_response.startswith("[Error"):
+                 st.error(followup_response)
+             else:
+                 with st.chat_message("assistant", avatar="🤖"):
+                     st.markdown(followup_response)
+                 st.rerun()
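
Note: ui_updated.py is a single-script Streamlit app, so it is launched rather than imported. A minimal sketch of a local run (the helper script name and the placeholder key are assumptions; any of the four supported provider env vars works, and this is equivalent to running `streamlit run app/ui_updated.py` yourself):

    # launch_app.py -- hypothetical helper, not part of this commit
    import os
    import subprocess

    os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder, not a real key
    subprocess.run(["streamlit", "run", "app/ui_updated.py"], check=True)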
assets/Logo_light.png ADDED

Git LFS Details

  • SHA256: 4237eeb306339868507feb9ae60b6c5ab5980abe769d8b26d3635eed55e9714f
  • Pointer size: 131 Bytes
  • Size of remote file: 317 kB
backend/__init__.py ADDED
@@ -0,0 +1,81 @@
+ # === Inference components ===
+ from .inference.process_beta import (
+     load_beta_matrix,
+     get_top_words_at_time,
+     get_top_words_over_time,
+     load_time_labels
+ )
+
+ from .inference.indexing_utils import load_index
+ from .inference.word_selector import (
+     get_interesting_words,
+     get_word_trend
+ )
+ from .inference.peak_detector import detect_peaks
+ from .inference.doc_retriever import (
+     load_length_stats,
+     get_yearly_counts_for_word,
+     get_all_documents_for_word_year,
+     deduplicate_docs,
+     extract_snippet,
+     highlight,
+     get_docs_by_ids,
+ )
+
+ # === LLM components ===
+ from .llm_utils.label_generator import label_topic_temporal, get_topic_labels
+ from .llm_utils.token_utils import (
+     get_token_limit_for_model,
+     count_tokens,
+     estimate_avg_tokens_per_doc,
+     estimate_max_k,
+     estimate_max_k_fast
+ )
+ from .llm_utils.summarizer import (
+     summarize_docs,
+     summarize_multiword_docs,
+     ask_multiturn_followup
+ )
+ from .llm.llm_router import (
+     list_supported_models,
+     get_llm
+ )
+
+ # === Dataset utilities ===
+ from .datasets import dynamic_dataset
+ from .datasets import preprocess
+ from .datasets.utils import logger, _utils
+ from .datasets.data import file_utils, download
+
+ # === Evaluation ===
+ from .evaluation.CoherenceModel_ttc import CoherenceModel_ttc
+ from .evaluation.eval import TopicQualityAssessor
+
+ # === Models ===
+ from .models.DETM import DETM
+ from .models.DTM_trainer import DTMTrainer
+ from .models.CFDTM.CFDTM import CFDTM
+ from .models.dynamic_trainer import DynamicTrainer
+
+ __all__ = [
+     # Inference
+     "load_beta_matrix", "load_time_labels", "get_top_words_at_time", "get_top_words_over_time",
+     "load_index", "get_interesting_words", "get_word_trend", "detect_peaks",
+     "load_length_stats", "get_yearly_counts_for_word", "get_all_documents_for_word_year",
+     "deduplicate_docs", "extract_snippet", "highlight", "get_docs_by_ids",
+
+     # LLM
+     "summarize_docs", "summarize_multiword_docs", "ask_multiturn_followup",
+     "get_token_limit_for_model", "list_supported_models", "get_llm",
+     "label_topic_temporal", "get_topic_labels", "count_tokens",
+     "estimate_avg_tokens_per_doc", "estimate_max_k", "estimate_max_k_fast",
+
+     # Dataset
+     "dynamic_dataset", "preprocess", "logger", "_utils", "file_utils", "download",
+
+     # Evaluation
+     "CoherenceModel_ttc", "TopicQualityAssessor",
+
+     # Models
+     "DETM", "DTMTrainer", "CFDTM", "DynamicTrainer"
+ ]
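
Note: since __init__.py re-exports the inference helpers, downstream code can import them from `backend` directly. A minimal usage sketch (the ACL_Anthology paths mirror the data files added in this commit; topic_id=0 is an arbitrary choice):

    from backend import load_beta_matrix, load_time_labels, get_top_words_over_time

    beta, vocab = load_beta_matrix("data/ACL_Anthology/DETM/beta.npy",
                                   "data/ACL_Anthology/processed/vocab.txt")
    time_labels = load_time_labels("data/ACL_Anthology/processed/time2id.txt")
    top_words = get_top_words_over_time(beta=beta, vocab=vocab, topic_id=0, top_n=10)
    print(time_labels[0], top_words[:5])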
backend/datasets/_preprocess.py ADDED
@@ -0,0 +1,447 @@
+ import os
+ import re
+ import string
+ import gensim.downloader
+ from collections import Counter
+ import numpy as np
+ import scipy.sparse
+ from tqdm import tqdm
+ from sklearn.feature_extraction.text import CountVectorizer
+
+ from backend.datasets.data import file_utils
+ from backend.datasets.utils._utils import get_stopwords_set
+ from backend.datasets.utils.logger import Logger
+ import json
+ import nltk
+ from nltk.stem import WordNetLemmatizer
+
+ logger = Logger("WARNING")
+
+ try:
+     nltk.data.find('corpora/wordnet')
+ except LookupError:
+     nltk.download('wordnet', quiet=True)
+ try:
+     nltk.data.find('corpora/omw-1.4')
+ except LookupError:
+     nltk.download('omw-1.4', quiet=True)
+
+ # compile some regexes
+ punct_chars = list(set(string.punctuation) - set("'"))
+ punct_chars.sort()
+ punctuation = ''.join(punct_chars)
+ replace = re.compile('[%s]' % re.escape(punctuation))
+ alpha = re.compile('^[a-zA-Z_]+$')
+ alpha_or_num = re.compile('^[a-zA-Z_]+|[0-9_]+$')
+ alphanum = re.compile('^[a-zA-Z0-9_]+$')
+
+
+ class Tokenizer:
+     def __init__(self,
+                  stopwords="English",
+                  keep_num=False,
+                  keep_alphanum=False,
+                  strip_html=False,
+                  no_lower=False,
+                  min_length=3,
+                  lemmatize=True,
+                  ):
+         self.keep_num = keep_num
+         self.keep_alphanum = keep_alphanum
+         self.strip_html = strip_html
+         self.lower = not no_lower
+         self.min_length = min_length
+
+         self.stopword_set = get_stopwords_set(stopwords)
+
+         self.lemmatize = lemmatize
+         if lemmatize:
+             self.lemmatizer = WordNetLemmatizer()
+
+     def clean_text(self, text, strip_html=False, lower=True, keep_emails=False, keep_at_mentions=False):
+         # remove html tags
+         if strip_html:
+             text = re.sub(r'<[^>]+>', '', text)
+         else:
+             # replace angle brackets
+             text = re.sub(r'<', '(', text)
+             text = re.sub(r'>', ')', text)
+         # lower case
+         if lower:
+             text = text.lower()
+         # eliminate email addresses
+         if not keep_emails:
+             text = re.sub(r'\S+@\S+', ' ', text)
+         # eliminate @mentions
+         if not keep_at_mentions:
+             text = re.sub(r'\s@\S+', ' ', text)
+         # replace underscores with spaces
+         text = re.sub(r'_', ' ', text)
+         # break off single quotes at the ends of words
+         text = re.sub(r'\s\'', ' ', text)
+         text = re.sub(r'\'\s', ' ', text)
+         # remove periods
+         text = re.sub(r'\.', '', text)
+         # replace all other punctuation (except single quotes) with spaces
+         text = replace.sub(' ', text)
+         # remove single quotes
+         text = re.sub(r'\'', '', text)
+         # replace all whitespace with a single space
+         text = re.sub(r'\s', ' ', text)
+         # strip off spaces on either end
+         text = text.strip()
+         return text
+
+     def tokenize(self, text):
+         text = self.clean_text(text, self.strip_html, self.lower)
+         tokens = text.split()
+
+         tokens = ['_' if t in self.stopword_set else t for t in tokens]
+
+         # remove tokens that contain numbers
+         if not self.keep_alphanum and not self.keep_num:
+             tokens = [t if alpha.match(t) else '_' for t in tokens]
+
+         # or just remove tokens that contain a combination of letters and numbers
+         elif not self.keep_alphanum:
+             tokens = [t if alpha_or_num.match(t) else '_' for t in tokens]
+
+         # drop short tokens
+         if self.min_length > 0:
+             tokens = [t if len(t) >= self.min_length else '_' for t in tokens]
+
+         if getattr(self, "lemmatize", False):
+             tokens = [self.lemmatizer.lemmatize(t) if t != '_' else t for t in tokens]
+
+         unigrams = [t for t in tokens if t != '_']
+         return unigrams
+
+
+ def make_word_embeddings(vocab):
+     glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')
+     word_embeddings = np.zeros((len(vocab), glove_vectors.vectors.shape[1]))
+
+     num_found = 0
+
+     try:
+         key_word_list = glove_vectors.index_to_key
+     except AttributeError:  # for older gensim versions
+         key_word_list = glove_vectors.index2word
+
+     for i, word in enumerate(tqdm(vocab, desc="loading word embeddings")):
+         if word in key_word_list:
+             word_embeddings[i] = glove_vectors[word]
+             num_found += 1
+
+     logger.info(f'number of found embeddings: {num_found}/{len(vocab)}')
+
+     return scipy.sparse.csr_matrix(word_embeddings)
+
+
+ class Preprocess:
+     def __init__(self,
+                  tokenizer=None,
+                  test_sample_size=None,
+                  test_p=0.2,
+                  stopwords="English",
+                  min_doc_count=0,
+                  max_doc_freq=1.0,
+                  keep_num=False,
+                  keep_alphanum=False,
+                  strip_html=False,
+                  no_lower=False,
+                  min_length=3,
+                  min_term=0,
+                  vocab_size=None,
+                  seed=42,
+                  verbose=True,
+                  lemmatize=True,
+                  ):
+         """
+         Args:
+             test_sample_size:
+                 Size of the test set.
+             test_p:
+                 Proportion of the test set. This helps sample the train set based on the size of the test set.
+             stopwords:
+                 List of stopwords to exclude.
+             min_doc_count:
+                 Exclude words that occur in fewer than this number of documents.
+             max_doc_freq:
+                 Exclude words that occur in more than this proportion of documents.
+             keep_num:
+                 Keep tokens made of only numbers.
+             keep_alphanum:
+                 Keep tokens made of a mixture of letters and numbers.
+             strip_html:
+                 Strip HTML tags.
+             no_lower:
+                 Do not lowercase the text.
+             min_length:
+                 Minimum token length.
+             min_term:
+                 Minimum number of terms per document.
+             vocab_size:
+                 Size of the vocabulary (by most common in the union of train and test sets, following the above exclusions).
+             seed:
+                 Random integer seed (only relevant for choosing the test set).
+             lemmatize:
+                 Whether to apply lemmatization to the tokens.
+         """
+
+         self.test_sample_size = test_sample_size
+         self.min_doc_count = min_doc_count
+         self.max_doc_freq = max_doc_freq
+         self.min_term = min_term
+         self.test_p = test_p
+         self.vocab_size = vocab_size
+         self.seed = seed
+
+         if tokenizer is not None:
+             self.tokenizer = tokenizer
+         else:
+             self.tokenizer = Tokenizer(
+                 stopwords,
+                 keep_num,
+                 keep_alphanum,
+                 strip_html,
+                 no_lower,
+                 min_length,
+                 lemmatize=lemmatize
+             ).tokenize
+
+         if verbose:
+             logger.set_level("DEBUG")
+         else:
+             logger.set_level("WARNING")
+
+     def parse(self, texts, vocab):
+         if not isinstance(texts, list):
+             texts = [texts]
+
+         vocab_set = set(vocab)
+         parsed_texts = list()
+         for i, text in enumerate(tqdm(texts, desc="parsing texts")):
+             tokens = self.tokenizer(text)
+             tokens = [t for t in tokens if t in vocab_set]
+             parsed_texts.append(" ".join(tokens))
+
+         vectorizer = CountVectorizer(vocabulary=vocab, tokenizer=lambda x: x.split())
+         sparse_bow = vectorizer.fit_transform(parsed_texts)
+         return parsed_texts, sparse_bow
+
+     def preprocess_jsonlist(self, dataset_dir, label_name=None, use_partition=True):
+         if use_partition:
+             train_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'train.jsonlist'))
+             test_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'test.jsonlist'))
+         else:
+             raw_path = os.path.join(dataset_dir, 'docs.jsonl')
+             with open(raw_path, 'r', encoding='utf-8') as f:
+                 train_items = [json.loads(line.strip()) for line in f if line.strip()]
+             test_items = []
+
+         logger.info(f"Found {len(train_items)} training documents and {len(test_items)} testing documents")
+
+         # Initialize containers
+         raw_train_texts, train_labels, raw_train_times = [], [], []
+         raw_test_texts, test_labels, raw_test_times = [], [], []
+
+         # Process train items
+         for item in train_items:
+             raw_train_texts.append(item['text'])
+             raw_train_times.append(str(item['timestamp']))
+             if label_name and label_name in item:
+                 train_labels.append(item[label_name])
+
+         # Process test items
+         for item in test_items:
+             raw_test_texts.append(item['text'])
+             raw_test_times.append(str(item['timestamp']))
+             if label_name and label_name in item:
+                 test_labels.append(item[label_name])
+
+         # Create and apply time2id mapping
+         all_times = sorted(set(raw_train_times + raw_test_times))
+         time2id = {t: i for i, t in enumerate(all_times)}
+
+         train_times = np.array([time2id[t] for t in raw_train_times], dtype=np.int32)
+         test_times = np.array([time2id[t] for t in raw_test_times], dtype=np.int32) if raw_test_times else None
+
+         # Preprocess and get sample indices
+         rst = self.preprocess(raw_train_texts, train_labels, raw_test_texts, test_labels)
+         train_idx = rst.get("train_idx")
+         test_idx = rst.get("test_idx")
+
+         # Add filtered timestamps to result for saving later
+         rst["train_times"] = train_times[train_idx]
+         if test_times is not None and test_idx is not None:
+             rst["test_times"] = test_times[test_idx]
+
+         # Add time2id to result dict
+         rst["time2id"] = time2id
+
+         return rst
+
+
+     def convert_labels(self, train_labels, test_labels):
+         if train_labels:
+             label_list = list(set(train_labels).union(set(test_labels)))
+             label_list.sort()
+             n_labels = len(label_list)
+             label2id = dict(zip(label_list, range(n_labels)))
+
+             logger.info(f"label2id: {label2id}")
+
+             train_labels = [label2id[label] for label in train_labels]
+
+             if test_labels:
+                 test_labels = [label2id[label] for label in test_labels]
+
+         return train_labels, test_labels
+
+     def preprocess(
+         self,
+         raw_train_texts,
+         train_labels=None,
+         raw_test_texts=None,
+         test_labels=None,
+         pretrained_WE=True
+     ):
+         np.random.seed(self.seed)
+
+         train_texts = list()
+         test_texts = list()
+         word_counts = Counter()
+         doc_counts_counter = Counter()
+
+         train_labels, test_labels = self.convert_labels(train_labels, test_labels)
+
+         for text in tqdm(raw_train_texts, desc="loading train texts"):
+             tokens = self.tokenizer(text)
+             word_counts.update(tokens)
+             doc_counts_counter.update(set(tokens))
+             parsed_text = ' '.join(tokens)
+             train_texts.append(parsed_text)
+
+         if raw_test_texts:
+             for text in tqdm(raw_test_texts, desc="loading test texts"):
+                 tokens = self.tokenizer(text)
+                 word_counts.update(tokens)
+                 doc_counts_counter.update(set(tokens))
+                 parsed_text = ' '.join(tokens)
+                 test_texts.append(parsed_text)
+
+         words, doc_counts = zip(*doc_counts_counter.most_common())
+         doc_freqs = np.array(doc_counts) / float(len(train_texts) + len(test_texts))
+
+         vocab = [word for i, word in enumerate(words) if doc_counts[i] >= self.min_doc_count and doc_freqs[i] <= self.max_doc_freq]
+
+         # filter vocabulary
+         if self.vocab_size is not None:
+             vocab = vocab[:self.vocab_size]
+
+         vocab.sort()
+
+         train_idx = [i for i, text in enumerate(train_texts) if len(text.split()) >= self.min_term]
+         train_idx = np.asarray(train_idx)
+
+         if raw_test_texts is not None:
+             test_idx = [i for i, text in enumerate(test_texts) if len(text.split()) >= self.min_term]
+             test_idx = np.asarray(test_idx)
+         else:
+             test_idx = None
+
+         # randomly sample
+         if self.test_sample_size and raw_test_texts is not None:
+             logger.info("sample train and test sets...")
+
+             train_num = len(train_idx)
+             test_num = len(test_idx)
+             test_sample_size = min(test_num, self.test_sample_size)
+             train_sample_size = int((test_sample_size / self.test_p) * (1 - self.test_p))
+             if train_sample_size > train_num:
+                 test_sample_size = int((train_num / (1 - self.test_p)) * self.test_p)
+                 train_sample_size = train_num
+
+             train_idx = train_idx[np.sort(np.random.choice(train_num, train_sample_size, replace=False))]
+             test_idx = test_idx[np.sort(np.random.choice(test_num, test_sample_size, replace=False))]
+
+             logger.info(f"sampled train size: {len(train_idx)}")
+             logger.info(f"sampled test size: {len(test_idx)}")
+
+         train_texts, train_bow = self.parse([train_texts[i] for i in train_idx], vocab)
+
+         rst = {
+             'vocab': vocab,
+             'train_bow': train_bow,
+             "train_texts": train_texts,
+             "train_idx": train_idx,  # <--- NEW: indices of kept train samples
+         }
+
+         if train_labels:
+             rst['train_labels'] = np.asarray(train_labels)[train_idx]
+
+         logger.info(f"Real vocab size: {len(vocab)}")
+         logger.info(f"Real training size: {len(train_texts)} \t avg length: {rst['train_bow'].sum() / len(train_texts):.3f}")
+
+         if raw_test_texts:
+             rst['test_texts'], rst['test_bow'] = self.parse(np.asarray(test_texts)[test_idx].tolist(), vocab)
+             rst["test_idx"] = test_idx  # <--- NEW: indices of kept test samples
+
+             if test_labels:
+                 rst['test_labels'] = np.asarray(test_labels)[test_idx]
+
+             logger.info(f"Real testing size: {len(rst['test_texts'])} \t avg length: {rst['test_bow'].sum() / len(rst['test_texts']):.3f}")
+
+         if pretrained_WE:
+             rst['word_embeddings'] = make_word_embeddings(vocab)
+
+         return rst
+
+     def save(
+         self,
+         output_dir,
+         vocab,
+         train_texts,
+         train_bow,
+         word_embeddings=None,
+         train_labels=None,
+         test_texts=None,
+         test_bow=None,
+         test_labels=None,
+         train_times=None,
+         test_times=None,
+         time2id=None  # <-- new parameter
+     ):
+         file_utils.make_dir(output_dir)
+
+         file_utils.save_text(vocab, f"{output_dir}/vocab.txt")
+         file_utils.save_text(train_texts, f"{output_dir}/train_texts.txt")
+         scipy.sparse.save_npz(f"{output_dir}/train_bow.npz", scipy.sparse.csr_matrix(train_bow))
+
+         if word_embeddings is not None:
+             scipy.sparse.save_npz(f"{output_dir}/word_embeddings.npz", word_embeddings)
+
+         if train_labels:
+             np.savetxt(f"{output_dir}/train_labels.txt", train_labels, fmt='%i')
+
+         if train_times is not None:
+             np.savetxt(f"{output_dir}/train_times.txt", train_times, fmt='%i')
+
+         if test_bow is not None:
+             scipy.sparse.save_npz(f"{output_dir}/test_bow.npz", scipy.sparse.csr_matrix(test_bow))
+
+         if test_texts is not None:
+             file_utils.save_text(test_texts, f"{output_dir}/test_texts.txt")
+
+         if test_labels:
+             np.savetxt(f"{output_dir}/test_labels.txt", test_labels, fmt='%i')
+
+         if test_times is not None:
+             np.savetxt(f"{output_dir}/test_times.txt", test_times, fmt='%i')
+
+         # Save time2id mapping if provided
+         if time2id is not None:
+             with open(f"{output_dir}/time2id.txt", "w", encoding="utf-8") as f:
+                 json.dump(time2id, f, indent=2)
+
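Note: Preprocess.preprocess_jsonlist drives the whole pipeline (tokenize, build the vocabulary, compute the bag-of-words, align timestamps), and save writes the artifacts that DynamicDataset later loads. A minimal usage sketch (the dataset directory and thresholds are assumptions):

    from backend.datasets._preprocess import Preprocess

    prep = Preprocess(min_doc_count=5, max_doc_freq=0.9, vocab_size=10000, lemmatize=True)
    rst = prep.preprocess_jsonlist("data/ACL_Anthology", use_partition=False)
    prep.save("data/ACL_Anthology/processed",
              vocab=rst["vocab"],
              train_texts=rst["train_texts"],
              train_bow=rst["train_bow"],
              word_embeddings=rst.get("word_embeddings"),
              train_times=rst.get("train_times"),
              time2id=rst.get("time2id"))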
backend/datasets/data/download.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ import zipfile
+ from torchvision.datasets.utils import download_url
+ from backend.datasets.utils.logger import Logger
+
+
+ logger = Logger("WARNING")
+
+
+ def download_dataset(dataset_name, cache_path="~/.topmost"):
+     cache_path = os.path.expanduser(cache_path)
+     raw_filename = f'{dataset_name}.zip'
+
+     if dataset_name in ['Wikitext-103']:
+         # download from Git LFS.
+         zipped_dataset_url = f"https://media.githubusercontent.com/media/BobXWu/TopMost/main/data/{raw_filename}"
+     else:
+         zipped_dataset_url = f"https://raw.githubusercontent.com/BobXWu/TopMost/master/data/{raw_filename}"
+
+     logger.info(zipped_dataset_url)
+
+     download_url(zipped_dataset_url, root=cache_path, filename=raw_filename, md5=None)
+
+     path = f'{cache_path}/{raw_filename}'
+     with zipfile.ZipFile(path, 'r') as zip_ref:
+         zip_ref.extractall(cache_path)
+
+     os.remove(path)
+
+
+ if __name__ == '__main__':
+     download_dataset('20NG')
backend/datasets/data/file_utils.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import json
+
+
+ def make_dir(path):
+     os.makedirs(path, exist_ok=True)
+
+
+ def read_text(path):
+     texts = list()
+     with open(path, 'r', encoding='utf-8', errors='ignore') as file:
+         for line in file:
+             texts.append(line.strip())
+     return texts
+
+
+ def save_text(texts, path):
+     with open(path, 'w', encoding='utf-8') as file:
+         for text in texts:
+             file.write(text.strip() + '\n')
+
+
+ def read_jsonlist(path):
+     data = list()
+     with open(path, 'r', encoding='utf-8') as input_file:
+         for line in input_file:
+             data.append(json.loads(line))
+     return data
+
+
+ def save_jsonlist(list_of_json_objects, path, sort_keys=True):
+     with open(path, 'w', encoding='utf-8') as output_file:
+         for obj in list_of_json_objects:
+             output_file.write(json.dumps(obj, sort_keys=sort_keys) + '\n')
+
+
+ def split_text_word(texts):
+     texts = [text.split() for text in texts]
+     return texts
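
Note: these helpers read and write the line-oriented formats used throughout the repo. A round-trip sketch (the paths mirror files added in this commit; the /tmp output path is an assumption):

    from backend.datasets.data import file_utils

    docs = file_utils.read_jsonlist("data/ACL_Anthology/docs.jsonl")
    file_utils.save_jsonlist(docs, "/tmp/docs_copy.jsonl")
    vocab = file_utils.read_text("data/ACL_Anthology/processed/vocab.txt")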
backend/datasets/dynamic_dataset.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ import scipy.sparse
+ import scipy.io
+ from backend.datasets.data import file_utils
+
+
+ class _SequentialDataset(Dataset):
+     def __init__(self, bow, times, time_wordfreq):
+         super().__init__()
+         self.bow = bow
+         self.times = times
+         self.time_wordfreq = time_wordfreq
+
+     def __len__(self):
+         return len(self.bow)
+
+     def __getitem__(self, index):
+         return_dict = {
+             'bow': self.bow[index],
+             'times': self.times[index],
+             'time_wordfreq': self.time_wordfreq[self.times[index]],
+         }
+
+         return return_dict
+
+
+ class DynamicDataset:
+     def __init__(self, dataset_dir, batch_size=200, read_labels=False, use_partition=False, device='cuda', as_tensor=True):
+
+         self.load_data(dataset_dir, read_labels, use_partition)
+
+         self.vocab_size = len(self.vocab)
+         self.train_size = len(self.train_bow)
+         self.num_times = int(self.train_times.max()) + 1  # assuming train_times is a numpy array
+         self.train_time_wordfreq = self.get_time_wordfreq(self.train_bow, self.train_times)
+
+         print('train size: ', len(self.train_bow))
+         if use_partition:
+             print('test size: ', len(self.test_bow))
+         print('vocab size: ', len(self.vocab))
+         print('average length: {:.3f}'.format(self.train_bow.sum(1).mean().item()))
+         print('num of each time slice: ', self.num_times, np.bincount(self.train_times))
+
+         if as_tensor:
+             self.train_bow = torch.from_numpy(self.train_bow).float().to(device)
+             self.train_times = torch.from_numpy(self.train_times).long().to(device)
+             self.train_time_wordfreq = torch.from_numpy(self.train_time_wordfreq).float().to(device)
+
+             if use_partition:
+                 self.test_bow = torch.from_numpy(self.test_bow).float().to(device)
+                 self.test_times = torch.from_numpy(self.test_times).long().to(device)
+
+             self.train_dataset = _SequentialDataset(self.train_bow, self.train_times, self.train_time_wordfreq)
+
+             if use_partition:
+                 self.test_dataset = _SequentialDataset(self.test_bow, self.test_times, self.train_time_wordfreq)
+
+             self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
+
+     def load_data(self, path, read_labels, use_partition=False):
+         self.train_bow = scipy.sparse.load_npz(f'{path}/train_bow.npz').toarray().astype('float32')
+         self.train_texts = file_utils.read_text(f'{path}/train_texts.txt')
+         self.train_times = np.loadtxt(f'{path}/train_times.txt').astype('int32')
+         self.vocab = file_utils.read_text(f'{path}/vocab.txt')
+         self.word_embeddings = scipy.sparse.load_npz(f'{path}/word_embeddings.npz').toarray().astype('float32')
+
+         self.pretrained_WE = self.word_embeddings  # preserve compatibility
+
+         if read_labels:
+             self.train_labels = np.loadtxt(f'{path}/train_labels.txt').astype('int32')
+
+         if use_partition:
+             self.test_bow = scipy.sparse.load_npz(f'{path}/test_bow.npz').toarray().astype('float32')
+             self.test_texts = file_utils.read_text(f'{path}/test_texts.txt')
+             self.test_times = np.loadtxt(f'{path}/test_times.txt').astype('int32')
+             if read_labels:
+                 self.test_labels = np.loadtxt(f'{path}/test_labels.txt').astype('int32')
+
+     # word frequency at each time slice.
+     def get_time_wordfreq(self, bow, times):
+         train_time_wordfreq = np.zeros((self.num_times, self.vocab_size))
+         for time in range(self.num_times):
+             idx = np.where(times == time)[0]
+             train_time_wordfreq[time] += bow[idx].sum(0)
+         cnt_times = np.bincount(times)
+
+         train_time_wordfreq = train_time_wordfreq / cnt_times[:, np.newaxis]
+         return train_time_wordfreq
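
Note: DynamicDataset stitches the saved preprocessing artifacts back together and exposes a shuffled DataLoader whose batches carry the per-slice word frequencies the dynamic topic models consume. A minimal sketch (the directory and device are assumptions; as_tensor=True is needed for the DataLoader to be created):

    from backend.datasets.dynamic_dataset import DynamicDataset

    dataset = DynamicDataset("data/ACL_Anthology/processed",
                             batch_size=200, use_partition=False, device="cpu")
    for batch in dataset.train_dataloader:
        bow = batch["bow"]              # (batch, vocab) term counts
        times = batch["times"]          # time-slice index per document
        freqs = batch["time_wordfreq"]  # word frequencies of that slice
        break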
backend/datasets/preprocess.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+ import tempfile
6
+ import gensim.downloader
7
+ from tqdm import tqdm
8
+ from backend.datasets.utils.logger import Logger
9
+ import scipy.sparse
10
+ from gensim.models.phrases import Phrases, Phraser
11
+ from typing import List, Union
12
+ from octis.preprocessing.preprocessing import Preprocessing
13
+
14
+ logger = Logger("WARNING")
15
+
16
+ class Preprocessor:
17
+ def __init__(self,
18
+ docs_jsonl_path: str,
19
+ output_folder: str,
20
+ use_partition: bool = False,
21
+ use_bigrams: bool = False,
22
+ min_count_bigram: int = 5,
23
+ threshold_bigram: int = 10,
24
+ remove_punctuation: bool = True,
25
+ lemmatize: bool = True,
26
+ stopword_list: Union[str, List[str]] = None,
27
+ min_chars: int = 3,
28
+ min_words_docs: int = 10,
29
+ min_df: Union[int, float] = 0.0,
30
+ max_df: Union[int, float] = 1.0,
31
+ max_features: int = None,
32
+ language: str = 'english'):
33
+
34
+ self.docs_jsonl_path = docs_jsonl_path
35
+ self.output_folder = output_folder
36
+ self.use_partition = use_partition
37
+ self.use_bigrams = use_bigrams
38
+ self.min_count_bigram = min_count_bigram
39
+ self.threshold_bigram = threshold_bigram
40
+
41
+ os.makedirs(self.output_folder, exist_ok=True)
42
+
43
+ self.preprocessing_params = {
44
+ 'remove_punctuation': remove_punctuation,
45
+ 'lemmatize': lemmatize,
46
+ 'stopword_list': stopword_list,
47
+ 'min_chars': min_chars,
48
+ 'min_words_docs': min_words_docs,
49
+ 'min_df': min_df,
50
+ 'max_df': max_df,
51
+ 'max_features': max_features,
52
+ 'language': language
53
+ }
54
+ self.preprocessor_octis = Preprocessing(**self.preprocessing_params)
55
+
56
+ def _load_data_to_temp_files(self):
57
+ """Loads data from JSONL and writes to temporary files for OCTIS preprocessor."""
58
+ raw_texts = []
59
+ raw_timestamps = []
60
+ raw_labels = []
61
+ has_labels = False
62
+
63
+ with open(self.docs_jsonl_path, 'r', encoding='utf-8') as f:
64
+ for line in f:
65
+ data = json.loads(line.strip())
66
+ # Remove newlines from text
67
+ clean_text = data.get('text', '').replace('\n', ' ').replace('\r', ' ')
68
+ clean_text = " ".join(clean_text.split())
69
+ raw_texts.append(clean_text)
70
+ raw_timestamps.append(data.get('timestamp', ''))
71
+ label = data.get('label', '')
72
+ if label:
73
+ has_labels = True
74
+ raw_labels.append(label)
75
+
76
+ # Create temporary files
77
+ temp_dir = tempfile.mkdtemp()
78
+ temp_docs_path = os.path.join(temp_dir, "temp_docs.txt")
79
+ temp_labels_path = None
80
+
81
+ with open(temp_docs_path, 'w', encoding='utf-8') as f_docs:
82
+ for text in raw_texts:
83
+ f_docs.write(f"{text}\n")
84
+
85
+ if has_labels:
86
+ temp_labels_path = os.path.join(temp_dir, "temp_labels.txt")
87
+ with open(temp_labels_path, 'w', encoding='utf-8') as f_labels:
88
+ for label in raw_labels:
89
+ f_labels.write(f"{label}\n")
90
+
91
+ print(f"Loaded {len(raw_texts)} raw documents and created temporary files in {temp_dir}.")
92
+ return raw_texts, raw_timestamps, raw_labels, temp_docs_path, temp_labels_path, temp_dir
93
+
94
+ def _make_word_embeddings(self, vocab):
95
+ """
96
+ Generates word embeddings for the given vocabulary using GloVe.
97
+ For n-grams (e.g., "wordA_wordB", "wordX_wordY_wordZ" for n>=2),
98
+ the resultant embedding is the sum of the embeddings of its constituent
99
+ single words (wordA + wordB + ...).
100
+ """
101
+ print("Loading GloVe word embeddings...")
102
+ glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')
103
+
104
+ # Initialize word_embeddings matrix with zeros.
105
+ # This ensures that words not found (single or n-gram constituents)
106
+ # will have a zero vector embedding.
107
+ word_embeddings = np.zeros((len(vocab), glove_vectors.vectors.shape[1]), dtype=np.float32)
108
+
109
+ num_found = 0
110
+
111
+ try:
112
+ # Using a set for key_word_list for O(1) average time complexity lookup
113
+ key_word_list = set(glove_vectors.index_to_key)
114
+ except AttributeError: # For older gensim versions
115
+ key_word_list = set(glove_vectors.index2word)
116
+
117
+ print("Generating word embeddings for vocabulary (including n-grams)...")
118
+ for i, word in enumerate(tqdm(vocab, desc="Processing vocabulary words")):
119
+ if '_' in word: # Check if it's a potential n-gram (n >= 2)
120
+ parts = word.split('_')
121
+
122
+ # Check if *all* constituent words are present in GloVe
123
+ all_parts_in_glove = True
124
+ for part in parts:
125
+ if part not in key_word_list:
126
+ all_parts_in_glove = False
127
+ break # One part not found, stop checking
128
+
129
+ if all_parts_in_glove:
130
+ # If all parts are found, sum their embeddings
131
+ resultant_vector = np.zeros(glove_vectors.vectors.shape[1], dtype=np.float32)
132
+ for part in parts:
133
+ resultant_vector += glove_vectors[part]
134
+
135
+ word_embeddings[i] = resultant_vector
136
+ num_found += 1
137
+ # Else: one or more constituent words not found, embedding remains zero
138
+ else: # It's a single word (n=1)
139
+ if word in key_word_list:
140
+ word_embeddings[i] = glove_vectors[word]
141
+ num_found += 1
142
+ # Else: single word not found, embedding remains zero
143
+
144
+ logger.info(f'Number of found embeddings (including n-grams): {num_found}/{len(vocab)}')
145
+ return word_embeddings # Return as dense NumPy array
146
+
147
+
148
+ def _save_doc_length_stats(self, filepath: str, output_path: str):
149
+ doc_lengths = []
150
+ try:
151
+ with open(filepath, 'r', encoding='utf-8') as f:
152
+ for line in f:
153
+ doc = line.strip()
154
+ if doc:
155
+ doc_lengths.append(len(doc))
156
+ except Exception as e:
157
+ print(f"Error processing '{filepath}': {e}")
158
+ return
159
+
160
+ if not doc_lengths:
161
+ print(f"No documents found in '{filepath}'.")
162
+ return
163
+
164
+ stats = {
165
+ "avg_len": float(np.mean(doc_lengths)),
166
+ "std_len": float(np.std(doc_lengths)),
167
+ "max_len": int(np.max(doc_lengths)),
168
+ "min_len": int(np.min(doc_lengths)),
169
+ "num_docs": int(len(doc_lengths))
170
+ }
171
+
172
+ with open(output_path, 'w', encoding='utf-8') as f:
173
+ json.dump(stats, f, indent=4)
174
+ print(f"Saved document length stats to: {output_path}")
175
+
176
+
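For reference, the resulting `length_stats.json` has this shape (illustrative numbers, not from a real run):

{
    "avg_len": 1482.3,
    "std_len": 512.9,
    "max_len": 9101,
    "min_len": 87,
    "num_docs": 12000
}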
177
+ def preprocess(self):
178
+ print("Loading data and creating temporary files for OCTIS...")
179
+ _, raw_timestamps, _, temp_docs_path, temp_labels_path, temp_dir = \
180
+ self._load_data_to_temp_files()
181
+
182
+ print("Starting OCTIS pre-processing using file paths and specified parameters...")
183
+ octis_dataset = self.preprocessor_octis.preprocess_dataset(
184
+ documents_path=temp_docs_path,
185
+ labels_path=temp_labels_path
186
+ )
187
+
188
+ # Clean up temporary files immediately
189
+ os.remove(temp_docs_path)
190
+ if temp_labels_path:
191
+ os.remove(temp_labels_path)
192
+ os.rmdir(temp_dir)
193
+ print(f"Temporary files in {temp_dir} cleaned up.")
194
+
195
+ # --- Workaround: persist OCTIS's private __original_indexes to a file, then read it back ---
196
+ temp_indexes_dir = tempfile.mkdtemp()
197
+ temp_indexes_file = os.path.join(temp_indexes_dir, "temp_original_indexes.txt")
198
+
199
+ print(f"Saving __original_indexes to {temp_indexes_file}...")
200
+ octis_dataset._save_document_indexes(temp_indexes_file)
201
+
202
+ # Manually load the indexes from the file
203
+ original_indexes_after_octis = []
204
+ with open(temp_indexes_file, 'r') as f_indexes:
205
+ for line in f_indexes:
206
+ original_indexes_after_octis.append(int(line.strip())) # Read as int
207
+
208
+ # Clean up the temporary indexes file and its directory
209
+ os.remove(temp_indexes_file)
210
+ os.rmdir(temp_indexes_dir)
211
+ print("Temporary indexes file cleaned up.")
212
+ # --- End Proxy ---
213
+
214
+ # Get processed data from OCTIS Dataset object
215
+ processed_corpus_octis_list = octis_dataset.get_corpus() # List of list of tokens
216
+ processed_labels_octis = octis_dataset.get_labels() # List of labels
217
+
218
+ print("Max index in original_indexes_after_octis:", max(original_indexes_after_octis))
219
+ print("Length of raw_timestamps:", len(raw_timestamps))
220
+
221
+ # Filter timestamps based on documents that survived OCTIS preprocessing
222
+ filtered_timestamps = [raw_timestamps[i] for i in original_indexes_after_octis]
223
+
224
+ print(f"OCTIS preprocessing complete. {len(processed_corpus_octis_list)} documents remaining.")
225
+
226
+ if self.use_bigrams:
227
+ print("Generating bigrams with Gensim...")
228
+ phrases = Phrases(processed_corpus_octis_list, min_count=self.min_count_bigram, threshold=self.threshold_bigram)
229
+ bigram_phraser = Phraser(phrases)
230
+ bigrammed_corpus_list = [bigram_phraser[doc] for doc in processed_corpus_octis_list]
231
+ print("Bigram generation complete.")
232
+ else:
233
+ print("Skipping bigram generation as 'use_bigrams' is False.")
234
+ bigrammed_corpus_list = processed_corpus_octis_list # Use the original processed list
235
+
236
+
237
+ # Join token lists back into strings for file output, keeping the token lists for BOW
238
+ bigrammed_texts_for_file = [" ".join(doc) for doc in bigrammed_corpus_list]
239
+
240
+
241
+ # Build Vocabulary from OCTIS output (after bigrams)
242
+ # We need a flat list of all tokens to build the vocabulary
243
+ all_tokens = [token for doc in bigrammed_corpus_list for token in doc]
244
+ vocab = sorted(set(all_tokens)) # Sorted unique words form the vocabulary
245
+ word_to_id = {word: i for i, word in enumerate(vocab)}
246
+
247
+ # Create BOW matrix manually
248
+ print("Creating Bag-of-Words representations...")
249
+ rows, cols, data = [], [], []
250
+ for i, doc_tokens in enumerate(bigrammed_corpus_list):
251
+ doc_word_counts = {}
252
+ for token in doc_tokens:
253
+ if token in word_to_id: # Ensure token is in our final vocab
254
+ doc_word_counts[word_to_id[token]] = doc_word_counts.get(word_to_id[token], 0) + 1
255
+ for col_id, count in doc_word_counts.items():
256
+ rows.append(i)
257
+ cols.append(col_id)
258
+ data.append(count)
259
+
260
+ # Shape is (num_documents, vocab_size)
261
+ bow_matrix = scipy.sparse.csc_matrix((data, (rows, cols)), shape=(len(bigrammed_corpus_list), len(vocab)))
262
+ print("Bag-of-Words complete.")
263
+
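The `(data, (rows, cols))` triplet passed to `scipy.sparse.csc_matrix` above is the standard SciPy constructor form; a tiny check of its semantics:

>>> import scipy.sparse
>>> m = scipy.sparse.csc_matrix(([2, 1], ([0, 1], [1, 0])), shape=(2, 3))
>>> m.toarray()
array([[0, 2, 0],
       [1, 0, 0]])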
264
+ # Handle partitioning if required
265
+ if self.use_partition:
266
+ num_docs = len(bigrammed_corpus_list)
267
+ train_size = int(0.8 * num_docs)
268
+
269
+ train_texts = bigrammed_texts_for_file[:train_size]
270
+ train_bow_matrix = bow_matrix[:train_size]
271
+ train_timestamps = filtered_timestamps[:train_size]
272
+ train_labels = processed_labels_octis[:train_size] if processed_labels_octis else []
273
+
274
+ test_texts = bigrammed_texts_for_file[train_size:]
275
+ test_bow_matrix = bow_matrix[train_size:]
276
+ test_timestamps = filtered_timestamps[train_size:]
277
+ test_labels = processed_labels_octis[train_size:] if processed_labels_octis else []
278
+
279
+ else:
280
+ train_texts = bigrammed_texts_for_file
281
+ train_bow_matrix = bow_matrix
282
+ train_timestamps = filtered_timestamps
283
+ train_labels = processed_labels_octis
284
+ test_texts = []
285
+ test_timestamps = []
286
+ test_labels = []
287
+
288
+ # Generate word embeddings using the provided function
289
+ word_embeddings = self._make_word_embeddings(vocab)
290
+
291
+ # Process timestamps to 0, 1, 2...T and create time2id.txt
292
+ print("Processing timestamps...")
293
+ unique_timestamps = sorted(set(train_timestamps + test_timestamps))
294
+ time_to_id = {timestamp: i for i, timestamp in enumerate(unique_timestamps)}
295
+
296
+ train_times_ids = [time_to_id[ts] for ts in train_timestamps]
297
+ test_times_ids = [time_to_id[ts] for ts in test_timestamps] if self.use_partition else []
298
+ print("Timestamps processed.")
299
+
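The timestamp-to-ID mapping is simply the rank of each unique timestamp after sorting:

>>> unique = sorted(set([2003, 2001, 2003, 2002]))
>>> {ts: i for i, ts in enumerate(unique)}
{2001: 0, 2002: 1, 2003: 2}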
300
+ # Save files
301
+ print(f"Saving preprocessed files to {self.output_folder}...")
302
+
303
+ # 1. vocab.txt
304
+ with open(os.path.join(self.output_folder, "vocab.txt"), "w", encoding="utf-8") as f:
305
+ for word in vocab:
306
+ f.write(f"{word}\n")
307
+
308
+ # 2. train_texts.txt
309
+ train_text_path = os.path.join(self.output_folder, "train_texts.txt")
310
+ with open(train_text_path, "w", encoding="utf-8") as f:
311
+ for doc in train_texts:
312
+ f.write(f"{doc}\n")
313
+
314
+ # Save document length stats
315
+ doc_stats_path = os.path.join(self.output_folder, "length_stats.json")
316
+ self._save_doc_length_stats(train_text_path, doc_stats_path)
317
+
318
+ # 3. train_bow.npz
319
+ scipy.sparse.save_npz(os.path.join(self.output_folder, "train_bow.npz"), train_bow_matrix)
320
+
321
+ # 4. word_embeddings.npz
322
+ sparse_word_embeddings = scipy.sparse.csr_matrix(word_embeddings)
323
+ scipy.sparse.save_npz(os.path.join(self.output_folder, "word_embeddings.npz"), sparse_word_embeddings)
324
+
325
+ # 5. train_labels.txt (if labels exist)
326
+ if train_labels:
327
+ with open(os.path.join(self.output_folder, "train_labels.txt"), "w", encoding="utf-8") as f:
328
+ for label in train_labels:
329
+ f.write(f"{label}\n")
330
+
331
+ # 6. train_times.txt
332
+ with open(os.path.join(self.output_folder, "train_times.txt"), "w", encoding="utf-8") as f:
333
+ for time_id in train_times_ids:
334
+ f.write(f"{time_id}\n")
335
+
336
+ # Files for test set (if use_partition=True)
337
+ if self.use_partition:
338
+ # 7. test_bow.npz
339
+ scipy.sparse.save_npz(os.path.join(self.output_folder, "test_bow.npz"), test_bow_matrix)
340
+
341
+ # 8. test_texts.txt
342
+ with open(os.path.join(self.output_folder, "test_texts.txt"), "w", encoding="utf-8") as f:
343
+ for doc in test_texts:
344
+ f.write(f"{doc}\n")
345
+
346
+ # 9. test_labels.txt (if labels exist)
347
+ if test_labels:
348
+ with open(os.path.join(self.output_folder, "test_labels.txt"), "w", encoding="utf-8") as f:
349
+ for label in test_labels:
350
+ f.write(f"{label}\n")
351
+
352
+ # 10. test_times.txt
353
+ with open(os.path.join(self.output_folder, "test_times.txt"), "w", encoding="utf-8") as f:
354
+ for time_id in test_times_ids:
355
+ f.write(f"{time_id}\n")
356
+
357
+ # 11. time2id.txt
358
+ sorted_time_to_id = OrderedDict(sorted(time_to_id.items(), key=lambda item: item[1]))
359
+ with open(os.path.join(self.output_folder, "time2id.txt"), "w", encoding="utf-8") as f:
360
+ json.dump(sorted_time_to_id, f, indent=4)
361
+
362
+ print("All files saved successfully.")
backend/datasets/utils/_utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+ from backend.datasets.data import file_utils
3
+
4
+
5
+ def get_top_words(beta, vocab, num_top_words, verbose=False):
6
+ topic_str_list = list()
7
+ for i, topic_dist in enumerate(beta):
8
+ topic_words = np.array(vocab)[np.argsort(topic_dist)][:-(num_top_words + 1):-1]
9
+ topic_str = ' '.join(topic_words)
10
+ topic_str_list.append(topic_str)
11
+ if verbose:
12
+ print('Topic {}: {}'.format(i, topic_str))
13
+
14
+ return topic_str_list
15
+
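The reversed slice `[:-(num_top_words + 1):-1]` reads the argsort result from the end, i.e. highest-weight words first; a quick check with a toy beta row:

>>> import numpy as np
>>> vocab = np.array(['a', 'b', 'c', 'd'])
>>> beta_row = np.array([0.1, 0.5, 0.2, 0.2])
>>> vocab[np.argsort(beta_row)][:-(2 + 1):-1]
array(['b', 'd'], dtype='<U1')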
16
+
17
+ def get_stopwords_set(stopwords=()):  # immutable default avoids the mutable-default-argument pitfall
18
+ from backend.datasets.data.download import download_dataset
19
+
20
+ if stopwords == 'English':
21
+ from gensim.parsing.preprocessing import STOPWORDS as stopwords
22
+
23
+ elif stopwords in ['mallet', 'snowball']:
24
+ download_dataset('stopwords', cache_path='./')
25
+ path = f'./stopwords/{stopwords}_stopwords.txt'
26
+ stopwords = file_utils.read_text(path)
27
+
28
+ stopword_set = frozenset(stopwords)
29
+
30
+ return stopword_set
31
+
32
+
33
+ if __name__ == '__main__':
34
+ print(list(get_stopwords_set('English'))[:10])
35
+ print(list(get_stopwords_set('mallet'))[:10])
36
+ print(list(get_stopwords_set('snowball'))[:10])
37
+ print(list(get_stopwords_set())[:10])
backend/datasets/utils/logger.py ADDED
@@ -0,0 +1,29 @@
1
+ import logging
2
+
3
+
4
+ class Logger:
5
+ def __init__(self, level):
6
+ self.logger = logging.getLogger('TopMost')
7
+ self.set_level(level)
8
+ self._add_handler()
9
+ self.logger.propagate = False
10
+
11
+ def info(self, message):
12
+ self.logger.info(f"{message}")
13
+
14
+ def warning(self, message):
15
+ self.logger.warning(f"WARNING: {message}")
16
+
17
+ def set_level(self, level):
18
+ levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
19
+ if level in levels:
20
+ self.logger.setLevel(level)
21
+
22
+ def _add_handler(self):
23
+ sh = logging.StreamHandler()
24
+ sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
25
+ self.logger.addHandler(sh)
26
+
27
+ # Remove duplicate handlers
28
+ if len(self.logger.handlers) > 1:
29
+ self.logger.handlers = [self.logger.handlers[0]]
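Typical usage of this wrapper (a sketch):

>>> logger = Logger("INFO")
>>> logger.info("preprocessing started")
>>> logger.set_level("WARNING")
>>> logger.warning("vocabulary smaller than expected")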
backend/evaluation/CoherenceModel_ttc.py ADDED
@@ -0,0 +1,862 @@
1
+ import logging
2
+ import multiprocessing as mp
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+
7
+ from gensim import interfaces, matutils
8
+ from gensim import utils
9
+ from gensim.topic_coherence import (
10
+ segmentation, probability_estimation,
11
+ direct_confirmation_measure, indirect_confirmation_measure,
12
+ aggregation,
13
+ )
14
+ from gensim.topic_coherence.probability_estimation import unique_ids_from_segments
15
+
16
+ # Set up logging for this module
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Define sets for categorizing coherence measures based on their probability estimation method
20
+ BOOLEAN_DOCUMENT_BASED = {'u_mass'}
21
+ SLIDING_WINDOW_BASED = {'c_v', 'c_uci', 'c_npmi', 'c_w2v'}
22
+
23
+ # Create a namedtuple to define the structure of a coherence measure pipeline
24
+ # Each pipeline consists of a segmentation (seg), probability estimation (prob),
25
+ # confirmation measure (conf), and aggregation (aggr) function.
26
+ _make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')
27
+
28
+ # Define the supported coherence measures and their respective pipeline components
29
+ COHERENCE_MEASURES = {
30
+ 'u_mass': _make_pipeline(
31
+ segmentation.s_one_pre,
32
+ probability_estimation.p_boolean_document,
33
+ direct_confirmation_measure.log_conditional_probability,
34
+ aggregation.arithmetic_mean
35
+ ),
36
+ 'c_v': _make_pipeline(
37
+ segmentation.s_one_set,
38
+ probability_estimation.p_boolean_sliding_window,
39
+ indirect_confirmation_measure.cosine_similarity,
40
+ aggregation.arithmetic_mean
41
+ ),
42
+ 'c_w2v': _make_pipeline(
43
+ segmentation.s_one_set,
44
+ probability_estimation.p_word2vec,
45
+ indirect_confirmation_measure.word2vec_similarity,
46
+ aggregation.arithmetic_mean
47
+ ),
48
+ 'c_uci': _make_pipeline(
49
+ segmentation.s_one_one,
50
+ probability_estimation.p_boolean_sliding_window,
51
+ direct_confirmation_measure.log_ratio_measure,
52
+ aggregation.arithmetic_mean
53
+ ),
54
+ 'c_npmi': _make_pipeline(
55
+ segmentation.s_one_one,
56
+ probability_estimation.p_boolean_sliding_window,
57
+ direct_confirmation_measure.log_ratio_measure,
58
+ aggregation.arithmetic_mean
59
+ ),
60
+ }
61
+
62
+ # Define default sliding window sizes for different coherence measures
63
+ SLIDING_WINDOW_SIZES = {
64
+ 'c_v': 110,
65
+ 'c_w2v': 5,
66
+ 'c_uci': 10,
67
+ 'c_npmi': 10,
68
+ 'u_mass': None # u_mass does not use a sliding window
69
+ }
70
+
71
+
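Each entry in `COHERENCE_MEASURES` is a plain namedtuple of functions, so individual pipeline stages can be inspected or reused directly:

>>> measure = COHERENCE_MEASURES['c_npmi']
>>> measure.seg is segmentation.s_one_one
True
>>> SLIDING_WINDOW_SIZES['c_npmi']
10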
72
+ class CoherenceModel_ttc(interfaces.TransformationABC):
73
+ """Objects of this class allow for building and maintaining a model for topic coherence.
74
+
75
+ Examples
76
+ ---------
77
+ One way of using this feature is through providing a trained topic model. A dictionary has to be explicitly provided
78
+ if the model does not contain a dictionary already
79
+
80
+ .. sourcecode:: pycon
81
+
82
+ >>> from gensim.test.utils import common_corpus, common_dictionary
83
+ >>> from gensim.models.ldamodel import LdaModel
84
+ >>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
85
+ >>> # from your_module import CoherenceModel_ttc # if saved in a file
86
+ >>>
87
+ >>> model = LdaModel(common_corpus, 5, common_dictionary)
88
+ >>>
89
+ >>> cm = CoherenceModel_ttc(model=model, corpus=common_corpus, coherence='u_mass')
90
+ >>> coherence = cm.get_coherence() # get coherence value
91
+
92
+ Another way of using this feature is through providing tokenized topics such as:
93
+
94
+ .. sourcecode:: pycon
95
+
96
+ >>> from gensim.test.utils import common_corpus, common_dictionary
97
+ >>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
98
+ >>> # from your_module import CoherenceModel_ttc # if saved in a file
99
+ >>> topics = [
100
+ ... ['human', 'computer', 'system', 'interface'],
101
+ ... ['graph', 'minors', 'trees', 'eps']
102
+ ... ]
103
+ >>>
104
+ >>> cm = CoherenceModel_ttc(topics=topics, corpus=common_corpus, dictionary=common_dictionary, coherence='u_mass')
105
+ >>> coherence = cm.get_coherence() # get coherence value
106
+
107
+ """
108
+ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None,
109
+ window_size=None, keyed_vectors=None, coherence='c_v', topn=20, processes=-1):
110
+ """
111
+ Initializes the CoherenceModel_ttc.
112
+
113
+ Parameters
114
+ ----------
115
+ model : :class:`~gensim.models.basemodel.BaseTopicModel`, optional
116
+ Pre-trained topic model. Should be provided if `topics` is not provided.
117
+ Supports models that implement the `get_topics` method.
118
+ topics : list of list of str, optional
119
+ List of tokenized topics. If provided, `dictionary` must also be provided.
120
+ texts : list of list of str, optional
121
+ Tokenized texts, needed for coherence measures that use a sliding-window probability estimator (e.g., `c_v`, `c_uci`, `c_npmi`).
122
+ corpus : iterable of list of (int, number), optional
123
+ Corpus in Bag-of-Words format.
124
+ dictionary : :class:`~gensim.corpora.dictionary.Dictionary`, optional
125
+ Gensim dictionary mapping from word IDs to words, used to create the corpus.
126
+ If `model.id2word` is present and `dictionary` is None, `model.id2word` will be used.
127
+ window_size : int, optional
128
+ The size of the window to be used for coherence measures using boolean sliding window as their
129
+ probability estimator. For 'u_mass' this doesn't matter.
130
+ If None, default window sizes from `SLIDING_WINDOW_SIZES` are used.
131
+ keyed_vectors : :class:`~gensim.models.keyedvectors.KeyedVectors`, optional
132
+ Pre-trained word embeddings (e.g., Word2Vec model) for 'c_w2v' coherence.
133
+ coherence : {'u_mass', 'c_v', 'c_uci', 'c_npmi', 'c_w2v'}, optional
134
+ Coherence measure to be used.
135
+ 'u_mass' requires `corpus` (or `texts` which will be converted to corpus).
136
+ 'c_v', 'c_uci', 'c_npmi', 'c_w2v' require `texts`.
137
+ topn : int, optional
138
+ Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
139
+ processes : int, optional
140
+ Number of processes to use for probability estimation phase. Any value less than 1 will be interpreted as
141
+ `num_cpus - 1`. Defaults to -1.
142
+ """
143
+ # Ensure either a model or explicit topics are provided
144
+ if model is None and topics is None:
145
+ raise ValueError("One of 'model' or 'topics' has to be provided.")
146
+ # If topics are provided, a dictionary is mandatory to convert tokens to IDs
147
+ elif topics is not None and dictionary is None:
148
+ raise ValueError("Dictionary has to be provided if 'topics' are to be used.")
149
+
150
+ self.keyed_vectors = keyed_vectors
151
+ # Ensure a data source (keyed_vectors, texts, or corpus) is provided for coherence calculation
152
+ if keyed_vectors is None and texts is None and corpus is None:
153
+ raise ValueError("One of 'texts', 'corpus', or 'keyed_vectors' has to be provided.")
154
+
155
+ # Determine the dictionary to use
156
+ if dictionary is None:
157
+ # If no explicit dictionary, try to use the model's dictionary
158
+ if isinstance(model.id2word, utils.FakeDict):
159
+ # If model's id2word is a FakeDict, it means no proper dictionary is associated
160
+ raise ValueError(
161
+ "The associated dictionary should be provided with the corpus or 'id2word'"
162
+ " for topic model should be set as the associated dictionary.")
163
+ else:
164
+ self.dictionary = model.id2word
165
+ else:
166
+ self.dictionary = dictionary
167
+
168
+ # Store coherence type and window size
169
+ self.coherence = coherence
170
+ self.window_size = window_size
171
+ if self.window_size is None:
172
+ # Use default window size if not specified
173
+ self.window_size = SLIDING_WINDOW_SIZES[self.coherence]
174
+
175
+ # Store texts and corpus
176
+ self.texts = texts
177
+ self.corpus = corpus
178
+
179
+ # Validate inputs based on coherence type
180
+ if coherence in BOOLEAN_DOCUMENT_BASED:
181
+ # For document-based measures (e.g., u_mass), corpus is preferred
182
+ if utils.is_corpus(corpus)[0]:
183
+ self.corpus = corpus
184
+ elif self.texts is not None:
185
+ # If texts are provided, convert them to corpus format
186
+ self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
187
+ else:
188
+ raise ValueError(
189
+ "Either 'corpus' with 'dictionary' or 'texts' should "
190
+ "be provided for %s coherence." % coherence)
191
+
192
+ elif coherence == 'c_w2v' and keyed_vectors is not None:
193
+ # For c_w2v, keyed_vectors are needed
194
+ pass
195
+ elif coherence in SLIDING_WINDOW_BASED:
196
+ # For sliding window-based measures, texts are required
197
+ if self.texts is None:
198
+ raise ValueError("'texts' should be provided for %s coherence." % coherence)
199
+ else:
200
+ # Raise error if coherence type is not supported
201
+ raise ValueError("%s coherence is not currently supported." % coherence)
202
+
203
+ self._topn = topn
204
+ self._model = model
205
+ self._accumulator = None # Cached accumulator for probability estimation
206
+ self._topics = None # Store topics internally
207
+ self.topics = topics # Call the setter to initialize topics and accumulator state
208
+
209
+ # Determine the number of processes to use for parallelization
210
+ self.processes = processes if processes >= 1 else max(1, mp.cpu_count() - 1)
211
+
212
+ @classmethod
213
+ def for_models(cls, models, dictionary, topn=20, **kwargs):
214
+ """
215
+ Initialize a CoherenceModel_ttc with estimated probabilities for all of the given models.
216
+ This method extracts topics from each model and then uses `for_topics`.
217
+
218
+ Parameters
219
+ ----------
220
+ models : list of :class:`~gensim.models.basemodel.BaseTopicModel`
221
+ List of models to evaluate coherence of. Each model should implement
222
+ the `get_topics` method.
223
+ dictionary : :class:`~gensim.corpora.dictionary.Dictionary`
224
+ Gensim dictionary mapping from word IDs to words.
225
+ topn : int, optional
226
+ Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
227
+ kwargs : object
228
+ Additional arguments passed to the `CoherenceModel_ttc` constructor (e.g., `corpus`, `texts`, `coherence`).
229
+
230
+ Returns
231
+ -------
232
+ :class:`CoherenceModel_ttc`
233
+ CoherenceModel_ttc instance with estimated probabilities for all given models.
234
+
235
+ Example
236
+ -------
237
+ .. sourcecode:: pycon
238
+
239
+ >>> from gensim.test.utils import common_corpus, common_dictionary
240
+ >>> from gensim.models.ldamodel import LdaModel
241
+ >>> # from your_module import CoherenceModel_ttc
242
+ >>>
243
+ >>> m1 = LdaModel(common_corpus, 3, common_dictionary)
244
+ >>> m2 = LdaModel(common_corpus, 5, common_dictionary)
245
+ >>>
246
+ >>> cm = CoherenceModel_ttc.for_models([m1, m2], common_dictionary, corpus=common_corpus, coherence='u_mass')
247
+ >>> # To get coherences for each model:
248
+ >>> # model_coherences = cm.compare_model_topics([
249
+ >>> # CoherenceModel_ttc._get_topics_from_model(m1, topn=cm.topn),
250
+ >>> # CoherenceModel_ttc._get_topics_from_model(m2, topn=cm.topn)
251
+ >>> # ])
252
+ """
253
+ # Extract top words as lists for each model's topics
254
+ topics = [cls.top_topics_as_word_lists(model, dictionary, topn) for model in models]
255
+ kwargs['dictionary'] = dictionary
256
+ kwargs['topn'] = topn
257
+ # Use for_topics to initialize the coherence model with these topics
258
+ return cls.for_topics(topics, **kwargs)
259
+
260
+ @staticmethod
261
+ def top_topics_as_word_lists(model, dictionary, topn=20):
262
+ """
263
+ Get `topn` topics from a model as lists of words.
264
+
265
+ Parameters
266
+ ----------
267
+ model : :class:`~gensim.models.basemodel.BaseTopicModel`
268
+ Pre-trained topic model.
269
+ dictionary : :class:`~gensim.corpora.dictionary.Dictionary`
270
+ Gensim dictionary mapping from word IDs to words.
271
+ topn : int, optional
272
+ Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
273
+
274
+ Returns
275
+ -------
276
+ list of list of str
277
+ Top topics in list-of-list-of-words format.
278
+ """
279
+ # Ensure id2token mapping exists in the dictionary
280
+ if not dictionary.id2token:
281
+ dictionary.id2token = {v: k for k, v in dictionary.token2id.items()}
282
+
283
+ str_topics = []
284
+ for topic_distribution in model.get_topics():
285
+ # Get the indices of the topN words based on their probabilities
286
+ bestn_indices = matutils.argsort(topic_distribution, topn=topn, reverse=True)
287
+ # Convert word IDs back to words using the dictionary
288
+ best_words = [dictionary.id2token[_id] for _id in bestn_indices]
289
+ str_topics.append(best_words)
290
+ return str_topics
291
+
292
+ @classmethod
293
+ def for_topics(cls, topics_as_topn_terms, **kwargs):
294
+ """
295
+ Initialize a CoherenceModel_ttc with estimated probabilities for all of the given topics.
296
+ This is useful when you have raw topics (list of lists of words) and not a Gensim model object.
297
+
298
+ Parameters
299
+ ----------
300
+ topics_as_topn_terms : list of list of str
301
+ Each element in the top-level list should be a list of top-N words, one per topic.
302
+ For example: `[['word1', 'word2'], ['word3', 'word4']]`.
303
+
304
+ Returns
305
+ -------
306
+ :class:`~gensim.models.coherencemodel.CoherenceModel`
307
+ CoherenceModel_ttc with estimated probabilities for the given topics.
308
+ """
309
+ if not topics_as_topn_terms:
310
+ raise ValueError("len(topics_as_topn_terms) must be > 0.")
311
+ if any(len(topic_list) == 0 for topic_list in topics_as_topn_terms):
312
+ raise ValueError("Found an empty topic listing in `topics_as_topn_terms`.")
313
+
314
+ # Determine the maximum 'topn' value among the provided topics
315
+ # This will be used to initialize the CoherenceModel_ttc correctly for probability estimation
316
+ actual_topn_in_data = 0
317
+ for topic_list in topics_as_topn_terms:
318
+ for topic in topic_list:
319
+ actual_topn_in_data = max(actual_topn_in_data, len(topic))
320
+
321
+ # Use the provided 'topn' from kwargs, or the determined 'actual_topn_in_data',
322
+ # ensuring it's not greater than the actual data available.
323
+ # This allows for precomputing probabilities for a wider set of words if needed.
324
+ topn_for_coherence = min(kwargs.pop('topn', actual_topn_in_data), actual_topn_in_data)
325
+
326
+ # Flatten all topics into a single "super topic" for initial probability estimation.
327
+ # This ensures that all words relevant to *any* topic in the comparison set
328
+ # are included in the accumulator.
329
+ super_topic = utils.flatten(topics_as_topn_terms)
330
+
331
+ logger.info(
332
+ "Number of relevant terms for all %d models (or topic sets): %d",
333
+ len(topics_as_topn_terms), len(super_topic))
334
+
335
+ # Initialize CoherenceModel_ttc with the super topic to pre-estimate probabilities
336
+ # for all relevant words across all models.
337
+ # We pass `topics=[super_topic]` and `topn=len(super_topic)` to ensure all words
338
+ # are considered during the probability estimation phase.
339
+ cm = CoherenceModel_ttc(topics=[super_topic], topn=len(super_topic), **kwargs)
340
+ cm.estimate_probabilities() # Perform the actual probability estimation
341
+
342
+ # After estimation, set the 'topn' back to the desired value for coherence calculation.
343
+ cm.topn = topn_for_coherence
344
+ return cm
345
+
346
+ def __str__(self):
347
+ """Returns a string representation of the coherence measure pipeline."""
348
+ return str(self.measure)
349
+
350
+ @property
351
+ def model(self):
352
+ """
353
+ Get the current topic model used by the instance.
354
+
355
+ Returns
356
+ -------
357
+ :class:`~gensim.models.basemodel.BaseTopicModel`
358
+ The currently set topic model.
359
+ """
360
+ return self._model
361
+
362
+ @model.setter
363
+ def model(self, model):
364
+ """
365
+ Set the topic model for the instance. When a new model is set,
366
+ it triggers an update of the internal topics and checks if the accumulator needs recomputing.
367
+
368
+ Parameters
369
+ ----------
370
+ model : :class:`~gensim.models.basemodel.BaseTopicModel`
371
+ The new topic model to set.
372
+ """
373
+ self._model = model
374
+ if model is not None:
375
+ new_topics = self._get_topics() # Get topics from the new model
376
+ self._update_accumulator(new_topics) # Check and update accumulator if needed
377
+ self._topics = new_topics # Store the new topics
378
+
379
+ @property
380
+ def topn(self):
381
+ """
382
+ Get the number of top words (`_topn`) used for coherence calculation.
383
+
384
+ Returns
385
+ -------
386
+ int
387
+ The number of top words.
388
+ """
389
+ return self._topn
390
+
391
+ @topn.setter
392
+ def topn(self, topn):
393
+ """
394
+ Set the number of top words (`_topn`) to consider for coherence calculation.
395
+ If the new `topn` requires more words than currently loaded topics, and a model is available,
396
+ it will attempt to re-extract topics from the model.
397
+
398
+ Parameters
399
+ ----------
400
+ topn : int
401
+ The new number of top words.
402
+ """
403
+ # Get the length of the first topic to check current topic length
404
+ current_topic_length = len(self._topics[0])
405
+ # Determine if the new 'topn' requires more words than currently available in topics
406
+ requires_expansion = current_topic_length < topn
407
+
408
+ if self.model is not None:
409
+ self._topn = topn
410
+ if requires_expansion:
411
+ # If expansion is needed and a model is available, re-extract topics from the model.
412
+ # This call to the setter property `self.model = self._model` effectively re-runs
413
+ # the logic that extracts topics and updates the accumulator based on the new `_topn`.
414
+ self.model = self._model
415
+ else:
416
+ # If no model is available and expansion is required, raise an error
417
+ if requires_expansion:
418
+ raise ValueError("Model unavailable and topic sizes are less than topn=%d" % topn)
419
+ self._topn = topn # Topics will be truncated by the `topics` getter if needed
420
+
421
+ @property
422
+ def measure(self):
423
+ """
424
+ Returns the namedtuple representing the coherence pipeline functions
425
+ (segmentation, probability estimation, confirmation, aggregation)
426
+ based on the `self.coherence` type.
427
+
428
+ Returns
429
+ -------
430
+ namedtuple
431
+ Pipeline that contains needed functions/method for calculating coherence.
432
+ """
433
+ return COHERENCE_MEASURES[self.coherence]
434
+
435
+ @property
436
+ def topics(self):
437
+ """
438
+ Get the current topics. If the internally stored topics have more words
439
+ than `self._topn`, they are truncated to `self._topn` words.
440
+
441
+ Returns
442
+ -------
443
+ list of list of str
444
+ Topics as lists of word tokens.
445
+ """
446
+ # If the stored topics contain more words than `_topn`, truncate them
447
+ if len(self._topics[0]) > self._topn:
448
+ return [topic[:self._topn] for topic in self._topics]
449
+ else:
450
+ return self._topics
451
+
452
+ @topics.setter
453
+ def topics(self, topics):
454
+ """
455
+ Set the topics for the instance. This method converts topic words to their
456
+ corresponding dictionary IDs and updates the accumulator state.
457
+
458
+ Parameters
459
+ ----------
460
+ topics : list of list of str or list of list of int
461
+ Topics, either as lists of word tokens or lists of word IDs.
462
+ """
463
+ if topics is not None:
464
+ new_topics = []
465
+ for topic in topics:
466
+ # Ensure topic elements are converted to dictionary IDs (numpy array for efficiency)
467
+ topic_token_ids = self._ensure_elements_are_ids(topic)
468
+ new_topics.append(topic_token_ids)
469
+
470
+ if self.model is not None:
471
+ # Warn if both model and explicit topics are set, as they might be inconsistent
472
+ logger.warning(
473
+ "The currently set model '%s' may be inconsistent with the newly set topics",
474
+ self.model)
475
+ elif self.model is not None:
476
+ # If topics are None but a model exists, extract topics from the model
477
+ new_topics = self._get_topics()
478
+ logger.debug("Setting topics to those of the model: %s", self.model)
479
+ else:
480
+ new_topics = None
481
+
482
+ # Check if the accumulator needs to be recomputed based on the new topics
483
+ self._update_accumulator(new_topics)
484
+ self._topics = new_topics # Store the (ID-converted) topics
485
+
486
+ def _ensure_elements_are_ids(self, topic):
487
+ """
488
+ Internal helper to ensure that topic elements are converted to dictionary IDs.
489
+ Handles cases where input topic might be tokens or already IDs.
490
+
491
+ Parameters
492
+ ----------
493
+ topic : list of str or list of int
494
+ A single topic, either as a list of word tokens or word IDs.
495
+
496
+ Returns
497
+ -------
498
+ :class:`numpy.ndarray`
499
+ A numpy array of word IDs for the topic.
500
+
501
+ Raises
502
+ ------
503
+ ValueError
504
+ If the topic can be interpreted neither as a list of tokens nor as a list of valid dictionary IDs.
505
+ """
506
+ # If the elements are strings, convert tokens to IDs. Tokens missing from
507
+ # the dictionary are silently dropped.
508
+ if len(topic) == 0 or isinstance(topic[0], str):
509
+ return np.array([self.dictionary.token2id[token] for token in topic if token in self.dictionary.token2id])
510
+
511
+ # Otherwise assume `topic` is already a list of word IDs and keep only the valid
512
+ # dictionary entries. (Routing integer IDs through the token branch above would
513
+ # silently drop them all and return an empty array.)
514
+ valid_ids = [_id for _id in topic if _id in self.dictionary]
515
+ if not valid_ids and len(topic) > 0:
516
+ raise ValueError(
517
+ "Unable to interpret topic as either a list of tokens or a list of"
518
+ " valid IDs within the dictionary.")
519
+ return np.array(valid_ids)
520
+
521
+ def _update_accumulator(self, new_topics):
522
+ """
523
+ Internal helper to determine if the cached `_accumulator` (probability statistics)
524
+ needs to be wiped and recomputed due to changes in topics.
525
+ """
526
+ if self._relevant_ids_will_differ(new_topics):
527
+ logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
528
+ self._accumulator = None
529
+
530
+ def _relevant_ids_will_differ(self, new_topics):
531
+ """
532
+ Internal helper to check if the set of unique word IDs relevant to the new topics
533
+ is different from the IDs already covered by the current accumulator.
534
+
535
+ Parameters
536
+ ----------
537
+ new_topics : list of list of int
538
+ The new set of topics (as word IDs).
539
+
540
+ Returns
541
+ -------
542
+ bool
543
+ True if the relevant IDs will differ, False otherwise.
544
+ """
545
+ if self._accumulator is None or not self._topics_differ(new_topics):
546
+ return False
547
+
548
+ # Get unique IDs from the segmented new topics
549
+ new_set = unique_ids_from_segments(self.measure.seg(new_topics))
550
+ # Check if the current accumulator's relevant IDs are a superset of the new set.
551
+ # If not, it means the new topics introduce words not covered, so the accumulator needs updating.
552
+ return not self._accumulator.relevant_ids.issuperset(new_set)
553
+
554
+ def _topics_differ(self, new_topics):
555
+ """
556
+ Internal helper to check if the new topics are different from the currently stored topics.
557
+
558
+ Parameters
559
+ ----------
560
+ new_topics : list of list of int
561
+ The new set of topics (as word IDs).
562
+
563
+ Returns
564
+ -------
565
+ bool
566
+ True if topics are different, False otherwise.
567
+ """
568
+ # Compare topic arrays using numpy.array_equal for efficient comparison
569
+ return (new_topics is not None
570
+ and self._topics is not None
571
+ and not np.array_equal(new_topics, self._topics))
572
+
573
+ def _get_topics(self):
574
+ """
575
+ Internal helper function to extract top words (as IDs) from a trained topic model.
576
+ """
577
+ return self._get_topics_from_model(self.model, self.topn)
578
+
579
+ @staticmethod
580
+ def _get_topics_from_model(model, topn):
581
+ """
582
+ Internal static method to extract top `topn` words (as IDs) from a trained topic model.
583
+
584
+ Parameters
585
+ ----------
586
+ model : :class:`~gensim.models.basemodel.BaseTopicModel`
587
+ Pre-trained topic model (must implement `get_topics` method).
588
+ topn : int
589
+ Integer corresponding to the number of top words to extract.
590
+
591
+ Returns
592
+ -------
593
+ list of :class:`numpy.ndarray`
594
+ A list where each element is a numpy array of word IDs representing a topic's top words.
595
+
596
+ Raises
597
+ ------
598
+ AttributeError
599
+ If the provided model does not implement a `get_topics` method.
600
+ """
601
+ try:
602
+ # Iterate over the topic distributions from the model
603
+ # Use matutils.argsort to get the indices (word IDs) of the top `topn` words
604
+ return [
605
+ matutils.argsort(topic, topn=topn, reverse=True) for topic in
606
+ model.get_topics()
607
+ ]
608
+ except AttributeError:
609
+ raise ValueError(
610
+ "This topic model is not currently supported. Supported topic models"
611
+ " should implement the `get_topics` method.")
612
+
613
+ def segment_topics(self):
614
+ """
615
+ Segments the current topics using the segmentation function defined by the
616
+ chosen coherence measure (`self.measure.seg`).
617
+
618
+ Returns
619
+ -------
620
+ list of list of tuple
621
+ Segmented topics. The structure depends on the segmentation method (e.g., pairs of word IDs).
622
+ """
623
+ # Apply the segmentation function from the pipeline to the current topics
624
+ return self.measure.seg(self.topics)
625
+
626
+ def estimate_probabilities(self, segmented_topics=None):
627
+ """
628
+ Accumulates word occurrences and co-occurrences from texts or corpus
629
+ using the optimal probability estimation method for the chosen coherence metric.
630
+ This operation can be computationally intensive, especially for sliding window methods.
631
+
632
+ Parameters
633
+ ----------
634
+ segmented_topics : list of list of tuple, optional
635
+ Segmented topics. If None, `self.segment_topics()` is called internally.
636
+
637
+ Returns
638
+ -------
639
+ :class:`~gensim.topic_coherence.text_analysis.CorpusAccumulator`
640
+ An object that holds the accumulated statistics (word frequencies, co-occurrence frequencies).
641
+ """
642
+ if segmented_topics is None:
643
+ segmented_topics = self.segment_topics()
644
+
645
+ # Choose the appropriate probability estimation method based on the coherence type
646
+ if self.coherence in BOOLEAN_DOCUMENT_BASED:
647
+ self._accumulator = self.measure.prob(self.corpus, segmented_topics)
648
+ else:
649
+ kwargs = dict(
650
+ texts=self.texts, segmented_topics=segmented_topics,
651
+ dictionary=self.dictionary, window_size=self.window_size,
652
+ processes=self.processes)
653
+ if self.coherence == 'c_w2v':
654
+ kwargs['model'] = self.keyed_vectors # Pass keyed_vectors for word2vec based coherence
655
+
656
+ self._accumulator = self.measure.prob(**kwargs)
657
+
658
+ return self._accumulator
659
+
660
+ def get_coherence_per_topic(self, segmented_topics=None, with_std=False, with_support=False):
661
+ """
662
+ Calculates and returns a list of coherence values, one for each topic,
663
+ based on the pipeline's confirmation measure.
664
+
665
+ Parameters
666
+ ----------
667
+ segmented_topics : list of list of tuple, optional
668
+ Segmented topics. If None, `self.segment_topics()` is called internally.
669
+ with_std : bool, optional
670
+ If True, also includes the standard deviation across topic segment sets in addition
671
+ to the mean coherence for each topic. Defaults to False.
672
+ with_support : bool, optional
673
+ If True, also includes the "support" (number of pairwise similarity comparisons)
674
+ used to compute each topic's coherence. Defaults to False.
675
+
676
+ Returns
677
+ -------
678
+ list of float or list of tuple
679
+ A sequence of similarity measures for each topic.
680
+ If `with_std` or `with_support` is True, each element in the list will be a tuple
681
+ containing the coherence value and the requested additional statistics.
682
+ """
683
+ measure = self.measure
684
+ if segmented_topics is None:
685
+ segmented_topics = measure.seg(self.topics)
686
+
687
+ # Ensure probabilities are estimated before calculating coherence
688
+ if self._accumulator is None:
689
+ self.estimate_probabilities(segmented_topics)
690
+
691
+ kwargs = dict(with_std=with_std, with_support=with_support)
692
+ if self.coherence in BOOLEAN_DOCUMENT_BASED or self.coherence == 'c_w2v':
693
+ # These coherence types don't require specific additional kwargs for confirmation measure
694
+ pass
695
+ elif self.coherence == 'c_v':
696
+ # Specific kwargs for c_v's confirmation measure (cosine_similarity)
697
+ kwargs['topics'] = self.topics
698
+ kwargs['measure'] = 'nlr' # Normalized Log Ratio
699
+ kwargs['gamma'] = 1
700
+ else:
701
+ # For c_uci and c_npmi, 'normalize' parameter is relevant
702
+ kwargs['normalize'] = (self.coherence == 'c_npmi')
703
+
704
+ return measure.conf(segmented_topics, self._accumulator, **kwargs)
705
+
706
+ def aggregate_measures(self, topic_coherences):
707
+ """
708
+ Aggregates the individual topic coherence measures into a single overall score
709
+ using the pipeline's aggregation function (`self.measure.aggr`).
710
+
711
+ Parameters
712
+ ----------
713
+ topic_coherences : list of float
714
+ List of coherence values for each topic.
715
+
716
+ Returns
717
+ -------
718
+ float
719
+ The aggregated coherence value (e.g., arithmetic mean).
720
+ """
721
+ # Apply the aggregation function from the pipeline to the list of topic coherences
722
+ return self.measure.aggr(topic_coherences)
723
+
724
+ def get_coherence(self):
725
+ """
726
+ Calculates and returns the overall coherence value for the entire set of topics.
727
+ This is the main entry point for getting a single coherence score.
728
+
729
+ Returns
730
+ -------
731
+ float
732
+ The aggregated coherence value.
733
+ """
734
+ # First, get coherence values for each individual topic
735
+ confirmed_measures = self.get_coherence_per_topic()
736
+ # Then, aggregate these topic-level coherences into a single score
737
+ return self.aggregate_measures(confirmed_measures)
738
+
739
+ def compare_models(self, models):
740
+ """
741
+ Compares multiple topic models by their coherence values.
742
+ It extracts topics from each model and then calls `compare_model_topics`.
743
+
744
+ Parameters
745
+ ----------
746
+ models : list of :class:`~gensim.models.basemodel.BaseTopicModel`
747
+ A sequence of topic models to compare.
748
+
749
+ Returns
750
+ -------
751
+ list of (list of float, float)
752
+ A sequence where each element is a pair:
753
+ (list of average topic coherences for the model, overall model coherence).
754
+ """
755
+ # Extract topics (as word IDs) for each model using the internal helper
756
+ model_topics = [self._get_topics_from_model(model, self.topn) for model in models]
757
+ # Delegate to compare_model_topics for the actual coherence comparison
758
+ return self.compare_model_topics(model_topics)
759
+
760
+ def compare_model_topics(self, model_topics):
761
+ """
762
+ Performs coherence evaluation for each set of topics provided in `model_topics`.
763
+ This method is designed to be efficient by precomputing probabilities once if needed,
764
+ and then evaluating coherence for each set of topics.
765
+
766
+ Parameters
767
+ ----------
768
+ model_topics : list of list of list of int
769
+ A list where each element is itself a list of topics (each topic being a list of word IDs)
770
+ representing a set of topics (e.g., from a single model).
771
+
772
+ Returns
773
+ -------
774
+ list of (list of float, float)
775
+ A sequence where each element is a pair:
776
+ (list of average topic coherences for the topic set, overall topic set coherence).
777
+
778
+ Notes
779
+ -----
780
+ This method uses a heuristic of evaluating coherence at various `topn` values (e.g., 20, 15, 10, 5)
781
+ and averaging the results for robustness, as suggested in some research.
782
+ """
783
+ # Store original topics and topn to restore them after comparison
784
+ orig_topics = self._topics
785
+ orig_topn = self.topn
786
+
787
+ try:
788
+ # Perform the actual comparison
789
+ coherences = self._compare_model_topics(model_topics)
790
+ finally:
791
+ # Ensure original topics and topn are restored even if an error occurs
792
+ self.topics = orig_topics
793
+ self.topn = orig_topn
794
+
795
+ return coherences
796
+
797
+ def _compare_model_topics(self, model_topics):
798
+ """
799
+ Internal helper to get average topic and model coherences across multiple sets of topics.
800
+
801
+ Parameters
802
+ ----------
803
+ model_topics : list of list of list of int
804
+ A list where each element is a set of topics (list of lists of word IDs).
805
+
806
+ Returns
807
+ -------
808
+ list of (list of float, float)
809
+ A sequence of pairs:
810
+ (average topic coherences across different `topn` values for each topic,
811
+ overall model coherence averaged across different `topn` values).
812
+ """
813
+ coherences = []
814
+ # Define a grid of `topn` values to evaluate coherence.
815
+ # This provides a more robust average coherence value.
816
+ # It goes from `self.topn` down to `min(self.topn - 1, 4)` in steps of -5.
817
+ # e.g., if self.topn is 20, grid might be [20, 15, 10, 5].
818
+ # The `min(self.topn - 1, 4)` ensures at least some lower values are included,
819
+ # but also prevents trying `topn` values that are too small or negative.
820
+ last_topn_value = min(self.topn - 1, 4)
821
+ topn_grid = list(range(self.topn, last_topn_value, -5))
822
+ if not topn_grid or max(topn_grid) < 1: # Ensure at least one valid topn if range is empty or too small
823
+ topn_grid = [max(1, min(self.topn, 5))] # Use min of self.topn and 5, ensure at least 1
824
+
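For the default `topn=20`, the grid heuristic above evaluates coherence at four cut-offs:

>>> topn = 20
>>> list(range(topn, min(topn - 1, 4), -5))
[20, 15, 10, 5]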
825
+ for model_num, topics in enumerate(model_topics):
826
+ # Set the current topics for the instance to the topics of the model being evaluated
827
+ self.topics = topics
828
+
829
+ coherence_at_n = {} # Dictionary to store coherence results for different `topn` values
830
+ for n in topn_grid:
831
+ self.topn = n # Set the `topn` for the current evaluation round
832
+ topic_coherences = self.get_coherence_per_topic()
833
+
834
+ # Handle NaN values in topic coherences by imputing with the mean
835
+ filled_coherences = np.array(topic_coherences, dtype=float)
836
+ # Check for NaN values and replace them with the mean of non-NaN values.
837
+ # np.nanmean handles arrays with all NaNs gracefully by returning NaN.
838
+ if np.any(np.isnan(filled_coherences)):
839
+ mean_val = np.nanmean(filled_coherences)
840
+ if np.isnan(mean_val): # If all are NaN, mean_val will also be NaN. In this case, replace with 0 or a very small number.
841
+ filled_coherences[np.isnan(filled_coherences)] = 0.0 # Or another sensible default
842
+ else:
843
+ filled_coherences[np.isnan(filled_coherences)] = mean_val
844
+
845
+
846
+ # Store the topic-level coherences and the aggregated (overall) coherence for this `topn`
847
+ coherence_at_n[n] = (topic_coherences, self.aggregate_measures(filled_coherences))
848
+
849
+ # Unpack the stored coherences for different `topn` values
850
+ all_topic_coherences_at_n, all_avg_coherences_at_n = zip(*coherence_at_n.values())
851
+
852
+ # Calculate the average topic coherence across all `topn` values
853
+ # np.vstack stacks lists of topic coherences into a 2D array, then mean(0) computes mean for each topic.
854
+ avg_topic_coherences = np.vstack(all_topic_coherences_at_n).mean(axis=0)
855
+
856
+ # Calculate the overall model coherence by averaging the aggregated coherences from all `topn` values
857
+ model_coherence = np.mean(all_avg_coherences_at_n)
858
+
859
+ logger.info("Avg coherence for model %d: %.5f", model_num, model_coherence)
860
+ coherences.append((avg_topic_coherences.tolist(), model_coherence)) # Convert numpy array back to list for output
861
+
862
+ return coherences
backend/evaluation/eval.py ADDED
@@ -0,0 +1,179 @@
1
+ # backend/evaluation/eval.py: dynamic topic quality metrics
2
+ import numpy as np
3
+ import pandas as pd
4
+ from gensim.corpora.dictionary import Dictionary
5
+ from gensim.models.coherencemodel import CoherenceModel
6
+ from backend.evaluation.CoherenceModel_ttc import CoherenceModel_ttc
7
+ from typing import List, Dict
8
+
9
+ class TopicQualityAssessor:
10
+ """
11
+ Calculates various quality metrics for dynamic topic models from in-memory data.
12
+
13
+ This class provides methods to compute:
14
+ - Temporal Topic Coherence (TTC)
15
+ - Temporal Topic Smoothness (TTS)
16
+ - Temporal Topic Quality (TTQ)
17
+ - Yearly Topic Coherence (TC)
18
+ - Yearly Topic Diversity (TD)
19
+ - Yearly Topic Quality (TQ)
20
+ """
21
+
22
+ def __init__(self, topics: List[List[List[str]]], train_texts: List[List[str]], topn: int, coherence_type: str):
23
+ """
24
+ Initializes the TopicQualityAssessor with data in memory.
25
+
26
+ Args:
27
+ topics (List[List[List[str]]]): A nested list of topics with structure (T, K, W),
28
+ where T is time slices, K is topics, and W is words.
29
+ train_texts (List[List[str]]): A list of tokenized documents for the reference corpus.
30
+ topn (int): Number of top words per topic to consider for calculations.
31
+ coherence_type (str): The type of coherence to calculate (e.g., 'c_npmi', 'c_v').
32
+ """
33
+ # 1. Set texts and dictionary
34
+ self.texts = train_texts
35
+ self.dictionary = Dictionary(self.texts)
36
+
37
+ # 2. Process topics
38
+ # User provides topics as (T, K, W) -> List[timestamps][topics][words]
39
+ # Internal representation for temporal evolution is (K, T, W)
40
+ topics_array_T_K_W = np.array(topics, dtype=object)
41
+ if topics_array_T_K_W.ndim != 3:
42
+ raise ValueError(f"Input 'topics' must be a 3-dimensional list/array. Got {topics_array_T_K_W.ndim} dimensions.")
43
+ self.total_topics = topics_array_T_K_W.transpose(1, 0, 2) # Shape: (K, T, W)
44
+
45
+ # 3. Get dimensions
46
+ self.K, self.T, _ = self.total_topics.shape
47
+
48
+ # 4. Create topic groups for smoothness calculation (pairs of topics over time)
49
+ groups = []
50
+ for k in range(self.K):
51
+ time_pairs = []
52
+ for t in range(self.T - 1):
53
+ time_pairs.append([self.total_topics[k, t].tolist(), self.total_topics[k, t+1].tolist()])
54
+ groups.append(time_pairs)
55
+ self.group_topics = np.array(groups, dtype=object)
56
+
57
+ # 5. Create yearly topics (T, K, W) for TC/TD calculation
58
+ self.yearly_topics = self.total_topics.transpose(1, 0, 2)
59
+
60
+ # 6. Set parameters
61
+ self.topn = topn
62
+ self.coherence_type = coherence_type
63
+
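The (T, K, W) to (K, T, W) reordering in step 2 is a pure axis swap; checking shapes only:

>>> import numpy as np
>>> t_k_w = np.empty((4, 10, 15), dtype=object)   # T=4 slices, K=10 topics, W=15 words
>>> t_k_w.transpose(1, 0, 2).shape
(10, 4, 15)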
64
+ def _compute_coherence(self, topics: List[List[str]]) -> List[float]:
65
+ cm = CoherenceModel(
66
+ topics=topics, texts=self.texts, dictionary=self.dictionary,
67
+ coherence=self.coherence_type, topn=self.topn
68
+ )
69
+ return cm.get_coherence_per_topic()
70
+
71
+ def _compute_coherence_ttc(self, topics: List[List[str]]) -> List[float]:
72
+ cm = CoherenceModel_ttc(
73
+ topics=topics, texts=self.texts, dictionary=self.dictionary,
74
+ coherence=self.coherence_type, topn=self.topn
75
+ )
76
+ return cm.get_coherence_per_topic()
77
+
78
+ def _topic_smoothness(self, topics: List[List[str]]) -> float:
79
+ K = len(topics)
80
+ if K <= 1:
81
+ return 1.0 # Or 0.0, depending on definition. A single topic has no other topic to be dissimilar to.
82
+ scores = []
83
+ for i, base in enumerate(topics):
84
+ base_set = set(base[:self.topn])
85
+ others = [other for j, other in enumerate(topics) if j != i]
86
+ if not others:
87
+ return 1.0
88
+ overlaps = [len(base_set & set(other[:self.topn])) / self.topn for other in others]
89
+ scores.append(sum(overlaps) / len(overlaps))
90
+ return float(sum(scores) / K)
91
+
92
+ def get_ttq_dataframe(self) -> pd.DataFrame:
93
+ """Computes and returns a DataFrame with detailed TTQ metrics per topic chain."""
94
+ all_coh_scores, avg_coh_scores = [], []
95
+ for k in range(self.K):
96
+ coh_per_topic = self._compute_coherence_ttc(self.total_topics[k].tolist())
97
+ all_coh_scores.append(coh_per_topic)
98
+ avg_coh_scores.append(float(np.mean(coh_per_topic)))
99
+
100
+ all_smooth_scores, avg_smooth_scores = [], []
101
+ for k in range(self.K):
102
+ pair_scores = [self._topic_smoothness(pair) for pair in self.group_topics[k]]
103
+ all_smooth_scores.append(pair_scores)
104
+ avg_smooth_scores.append(float(np.mean(pair_scores)))
105
+
106
+ df = pd.DataFrame({
107
+ 'topic_idx': list(range(self.K)),
108
+ 'temporal_coherence': all_coh_scores,
109
+ 'temporal_smoothness': all_smooth_scores,
110
+ 'avg_temporal_coherence': avg_coh_scores,
111
+ 'avg_temporal_smoothness': avg_smooth_scores
112
+ })
113
+ df['ttq_product'] = df['avg_temporal_coherence'] * df['avg_temporal_smoothness']
114
+ return df
115
+
116
+ def get_tq_dataframe(self) -> pd.DataFrame:
117
+ """Computes and returns a DataFrame with detailed TQ metrics per time slice."""
118
+ all_coh, avg_coh, div = [], [], []
119
+ for t in range(self.T):
120
+ yearly_t_topics = self.yearly_topics[t].tolist()
121
+ coh_per_topic = self._compute_coherence(yearly_t_topics)
122
+ all_coh.append(coh_per_topic)
123
+ avg_coh.append(float(np.mean(coh_per_topic)))
124
+ div.append(1 - self._topic_smoothness(yearly_t_topics))
125
+
126
+ df = pd.DataFrame({
127
+ 'year': list(range(self.T)),
128
+ 'all_coherence': all_coh,
129
+ 'avg_coherence': avg_coh,
130
+ 'diversity': div
131
+ })
132
+ df['tq_product'] = df['avg_coherence'] * df['diversity']
133
+ return df
134
+
135
+ def get_ttc_score(self) -> float:
136
+ """Calculates the overall Temporal Topic Coherence (TTC)."""
137
+ ttq_df = self.get_ttq_dataframe()
138
+ return ttq_df['avg_temporal_coherence'].mean()
139
+
140
+ def get_tts_score(self) -> float:
141
+ """Calculates the overall Temporal Topic Smoothness (TTS)."""
142
+ ttq_df = self.get_ttq_dataframe()
143
+ return ttq_df['avg_temporal_smoothness'].mean()
144
+
145
+ def get_ttq_score(self) -> float:
146
+ """Calculates the overall Temporal Topic Quality (TTQ)."""
147
+ ttq_df = self.get_ttq_dataframe()
148
+ return ttq_df['ttq_product'].mean()
149
+
150
+ def get_tc_score(self) -> float:
151
+ """Calculates the overall yearly Topic Coherence (TC)."""
152
+ tq_df = self.get_tq_dataframe()
153
+ return tq_df['avg_coherence'].mean()
154
+
155
+ def get_td_score(self) -> float:
156
+ """Calculates the overall yearly Topic Diversity (TD)."""
157
+ tq_df = self.get_tq_dataframe()
158
+ return tq_df['diversity'].mean()
159
+
160
+ def get_tq_score(self) -> float:
161
+ """Calculates the overall yearly Topic Quality (TQ)."""
162
+ tq_df = self.get_tq_dataframe()
163
+ return tq_df['tq_product'].mean()
164
+
165
+ def get_dtq_summary(self) -> Dict[str, float]:
166
+ """
167
+ Computes all dynamic topic quality metrics and returns them in a dictionary.
168
+ """
169
+ ttq_df = self.get_ttq_dataframe()
170
+ tq_df = self.get_tq_dataframe()
171
+ summary = {
172
+ 'TTC': ttq_df['avg_temporal_coherence'].mean(),
173
+ 'TTS': ttq_df['avg_temporal_smoothness'].mean(),
174
+ 'TTQ': ttq_df['ttq_product'].mean(),
175
+ 'TC': tq_df['avg_coherence'].mean(),
176
+ 'TD': tq_df['diversity'].mean(),
177
+ 'TQ': tq_df['tq_product'].mean()
178
+ }
179
+ return summary
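End to end, the assessor is driven like this (a sketch with toy data; real topics would come from a trained dynamic topic model):

>>> topics = [                     # T=2 time slices, K=1 topic, W=3 words
...     [['model', 'topic', 'word']],
...     [['model', 'topic', 'corpus']],
... ]
>>> texts = [['model', 'topic', 'word'], ['topic', 'corpus', 'model']]
>>> assessor = TopicQualityAssessor(topics, texts, topn=3, coherence_type='c_npmi')
>>> summary = assessor.get_dtq_summary()   # {'TTC': ..., 'TTS': ..., 'TTQ': ..., ...}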
backend/inference/doc_retriever.py ADDED
@@ -0,0 +1,219 @@
1
+ import html
2
+ import json
3
+ import re
4
+ import os
5
+ from hashlib import md5
6
+
7
+ def deduplicate_docs(collected_docs):
8
+ seen = set()
9
+ unique_docs = []
10
+ for doc in collected_docs:
11
+ # Prefer unique ID if available
12
+ key = doc.get("id", md5(doc["text"].encode()).hexdigest())
13
+ if key not in seen:
14
+ seen.add(key)
15
+ unique_docs.append(doc)
16
+ return unique_docs
17
+
18
+ def load_length_stats(length_stats_path):
19
+ """
20
+ Loads document length statistics from a 'length_stats.json' file.
21
+ 
22
+ Args:
23
+ length_stats_path (str): Path to the 'length_stats.json' file.
24
+
25
+ Returns:
26
+ dict: A dictionary containing document length statistics.
27
+ """
28
+ if not os.path.exists(length_stats_path):
29
+ raise FileNotFoundError(f"'length_stats.json' not found at: {length_stats_path}")
30
+
31
+ with open(length_stats_path, "r") as f:
32
+ length_stats = json.load(f)
33
+
34
+ return length_stats
35
+
36
+ def get_yearly_counts_for_word(index, word):
37
+ if word not in index:
38
+ print(f"[ERROR] Word '{word}' not found in index.")
39
+ return [], []
40
+
41
+ year_counts = index[word]
42
+ sorted_items = sorted((int(year), len(doc_ids)) for year, doc_ids in year_counts.items())
43
+ years, counts = zip(*sorted_items) if sorted_items else ([], [])
44
+ return list(years), list(counts)
45
+
46
+
47
+ def get_all_documents_for_word_year(index, docs_file_path, word, year):
48
+ """
49
+ Returns all full documents (text + metadata) that contain a given word in a given year.
50
+
51
+ Parameters:
52
+ index (dict): Inverted index.
53
+ docs_file_path (str): Path to original jsonl corpus.
54
+ word (str): Word (unigram or bigram).
55
+ year (int): Year to retrieve docs for.
56
+
57
+ Returns:
58
+ List[Dict]: List of documents with 'id', 'timestamp', and 'text'.
59
+ """
60
+ year = int(year)
61
+
62
+ if word not in index or year not in index[word]:
63
+ return []
64
+
65
+ doc_ids = set(index[word][year])
66
+ results = []
67
+
68
+ try:
69
+ with open(docs_file_path, 'r', encoding='utf-8') as f:
70
+ for doc_id, line in enumerate(f):
71
+ if doc_id in doc_ids:
72
+ doc = json.loads(line)
73
+ results.append({
74
+ "id": doc_id,
75
+ "timestamp": doc.get("timestamp", "N/A"),
76
+ "text": doc["text"]
77
+ })
78
+ except Exception as e:
79
+ print(f"[ERROR] Could not load documents: {e}")
80
+
81
+ return results
82
+
83
+
84
+ def get_documents_with_all_words_for_year(index, docs_path, words, year):
85
+ doc_sets = []
86
+ all_doc_occurrences = {}
87
+
88
+ for word in words:
89
+ word_docs = get_all_documents_for_word_year(index, docs_path, word, year)
90
+ doc_sets.append(set(doc["id"] for doc in word_docs))
91
+ for doc in word_docs:
92
+ all_doc_occurrences.setdefault(doc["id"], doc)
93
+
94
+ common_doc_ids = set.intersection(*doc_sets) if doc_sets else set()
95
+ return [all_doc_occurrences[doc_id] for doc_id in common_doc_ids]
96
+
97
+
98
+ def get_intersection_doc_counts_by_year(index, docs_path, words, all_years):
99
+ year_counts = {}
100
+ for y in all_years:
101
+ docs = get_documents_with_all_words_for_year(index, docs_path, words, y)
102
+ year_counts[y] = len(docs)
103
+ return year_counts
104
+
105
+
106
+ def extract_snippet(text, query, window=30):
107
+ """
108
+ Return a short snippet around the first occurrence of the query word.
109
+ """
110
+ pattern = re.compile(re.escape(query.replace('_', ' ')), re.IGNORECASE)
111
+ match = pattern.search(text)
112
+ if not match:
113
+ return text[:200] + "..."
114
+
115
+ start = max(match.start() - window, 0)
116
+ end = min(match.end() + window, len(text))
117
+ snippet = text[start:end].strip()
118
+
119
+ return f"...{snippet}..."
120
+
121
+ def highlight(text, query, highlight_color="#FFD54F"):
122
+ """
123
+ Highlight all instances of the query term in text using a colored <mark> tag.
124
+ """
125
+ escaped_query = re.escape(query.replace('_', ' '))
126
+ pattern = re.compile(f"({escaped_query})", flags=re.IGNORECASE)
127
+
128
+ def replacer(match):
129
+ matched_text = html.escape(match.group(1))
130
+ return f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"
131
+
132
+ return pattern.sub(replacer, html.escape(text))
133
+
134
+ def highlight_words(text, query_words, highlight_color="#24F31D", lemma_to_forms=None):
135
+ """
136
+ Highlight all surface forms of each query lemma in the text using a colored <mark> tag.
137
+
138
+ Args:
139
+ text (str): The input raw document text.
140
+ query_words (List[str]): Lemmatized query tokens to highlight.
141
+ highlight_color (str): Color to use for highlighting.
142
+ lemma_to_forms (Dict[str, Set[str]]): Maps a lemma to its surface forms.
143
+ """
144
+ # Escape HTML special characters first
145
+ escaped_text = html.escape(text)
146
+
147
+ # Expand query words to include all surface forms
148
+ expanded_forms = set()
149
+ for lemma in query_words:
150
+ if lemma_to_forms and lemma in lemma_to_forms:
151
+ expanded_forms.update(lemma_to_forms[lemma])
152
+ else:
153
+ expanded_forms.add(lemma) # Fallback if map is missing
154
+
155
+ # Sort by length to avoid partial overlaps (e.g., "run" before "running")
156
+ sorted_queries = sorted(expanded_forms, key=lambda w: -len(w))
157
+
158
+ for word in sorted_queries:
159
+ # Match full word, case insensitive
160
+ pattern = re.compile(rf'\b({re.escape(word)})\b', flags=re.IGNORECASE)
161
+
162
+ def replacer(match):
163
+ matched_text = match.group(1)
164
+ return f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"
165
+
166
+ escaped_text = pattern.sub(replacer, escaped_text)
167
+
168
+ return escaped_text
169
+
170
+ def get_docs_by_ids(docs_file_path, doc_ids):
171
+ """
172
+ Efficiently retrieves specific documents from a .jsonl file by their line number (ID).
173
+
174
+ This function reads the file line-by-line and only parses the lines that match
175
+ the requested document IDs, avoiding loading the entire file into memory.
176
+
177
+ Args:
178
+ docs_file_path (str): The path to the documents.jsonl file.
179
+ doc_ids (list or set): A collection of document IDs (0-indexed line numbers) to retrieve.
180
+
181
+ Returns:
182
+ list[dict]: A list of document dictionaries that were found. Each dictionary
183
+ is augmented with an 'id' key corresponding to its line number.
184
+ """
185
+ # Use a set for efficient O(1) lookups.
186
+ doc_ids_to_find = set(doc_ids)
187
+ found_docs = {}
188
+
189
+ if not doc_ids_to_find:
190
+ return []
191
+
192
+ try:
193
+ with open(docs_file_path, 'r', encoding='utf-8') as f:
194
+ for i, line in enumerate(f):
195
+ # If the current line number is one we're looking for
196
+ if i in doc_ids_to_find:
197
+ try:
198
+ doc = json.loads(line)
199
+ # Explicitly add the line number as the 'id'
200
+ doc['id'] = i
201
+ found_docs[i] = doc
202
+ # Optimization: stop reading the file once all docs are found
203
+ if len(found_docs) == len(doc_ids_to_find):
204
+ break
205
+ except json.JSONDecodeError:
206
+ # Skip malformed lines but inform the user
207
+ print(f"[WARNING] Skipping malformed JSON on line {i+1} in {docs_file_path}")
208
+ continue
209
+
210
+ except FileNotFoundError:
211
+ print(f"[ERROR] Document file not found at: {docs_file_path}")
212
+ return []
213
+ except Exception as e:
214
+ print(f"[ERROR] An unexpected error occurred while reading documents: {e}")
215
+ return []
216
+
217
+ # Return the documents in the same order as the original doc_ids list
218
+ # This ensures consistency for downstream processing.
219
+ return [found_docs[doc_id] for doc_id in doc_ids if doc_id in found_docs]
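Usage sketch combining the helpers above; `index` comes from indexing_utils (next file) and the corpus path is an assumption matching this repo's data layout:

    docs = get_all_documents_for_word_year(
        index, "data/ACL_Anthology/docs.jsonl", word="parsing", year=2015
    )
    for doc in deduplicate_docs(docs)[:3]:
        snippet = extract_snippet(doc["text"], "parsing")
        print(highlight(snippet, "parsing"))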
backend/inference/indexing_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import spacy
5
+ from collections import defaultdict
6
+
7
+ # Load spaCy once
8
+ nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
9
+
10
+ def tokenize(text):
11
+ return re.findall(r"\b\w+\b", text.lower())
12
+
13
+ def has_bigram(tokens, bigram):
14
+ parts = bigram.split('_')
15
+ for i in range(len(tokens) - len(parts) + 1):
16
+ if tokens[i:i + len(parts)] == parts:
17
+ return True
18
+ return False
19
+
20
+ def build_inverse_lemma_map(docs_file_path, cache_path=None):
21
+ """
22
+ Build or load a mapping from lemma -> set of surface forms seen in corpus.
23
+ If cache_path is provided and exists, loads from it.
24
+ Else builds from scratch and saves to cache_path.
25
+ """
26
+ if cache_path and os.path.exists(cache_path):
27
+ print(f"[INFO] Loading cached lemma_to_forms from {cache_path}")
28
+ with open(cache_path, "r", encoding="utf-8") as f:
29
+ raw_map = json.load(f)
30
+ return {lemma: set(forms) for lemma, forms in raw_map.items()}
31
+
32
+ print(f"[INFO] Building inverse lemma map from {docs_file_path}...")
33
+ lemma_to_forms = defaultdict(set)
34
+
35
+ with open(docs_file_path, 'r', encoding='utf-8') as f:
36
+ for line in f:
37
+ doc = json.loads(line)
38
+ tokens = tokenize(doc['text'])
39
+ spacy_doc = nlp(" ".join(tokens))
40
+ for token in spacy_doc:
41
+ lemma_to_forms[token.lemma_].add(token.text.lower())
42
+
43
+ if cache_path:
44
+ print(f"[INFO] Saving lemma_to_forms to {cache_path}")
45
+ os.makedirs(os.path.dirname(cache_path), exist_ok=True)
46
+ with open(cache_path, "w", encoding="utf-8") as f:
47
+ json.dump({k: list(v) for k, v in lemma_to_forms.items()}, f, indent=2)
48
+
49
+ return lemma_to_forms
50
+
51
+ def build_inverted_index(docs_file_path, vocab_set, lemma_map_path=None):
52
+ vocab_unigrams = {w for w in vocab_set if '_' not in w}
53
+ vocab_bigrams = {w for w in vocab_set if '_' in w}
54
+
55
+ # Load or build lemma map
56
+ lemma_to_forms = build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path)
57
+
58
+ index = defaultdict(lambda: defaultdict(list))
59
+ docs = []
60
+ global_seen_words = set()
61
+
62
+ with open(docs_file_path, 'r', encoding='utf-8') as f:
63
+ for doc_id, line in enumerate(f):
64
+ doc = json.loads(line)
65
+ text = doc['text']
66
+ timestamp = int(doc['timestamp'])
67
+ docs.append({"text": text, "timestamp": timestamp})
68
+
69
+ tokens = tokenize(text)
70
+ token_set = set(tokens)
71
+ seen_words = set()
72
+
73
+ # Match all lemma queries using surface forms
74
+ for lemma in vocab_unigrams:
75
+ surface_forms = lemma_to_forms.get(lemma, set())
76
+ if token_set & surface_forms:
77
+ index[lemma][timestamp].append(doc_id)
78
+ seen_words.add(lemma)
79
+
80
+ for bigram in vocab_bigrams:
81
+ if bigram not in seen_words and has_bigram(tokens, bigram):
82
+ index[bigram][timestamp].append(doc_id)
83
+ seen_words.add(bigram)
84
+
85
+ global_seen_words.update(seen_words)
86
+
87
+ if (doc_id + 1) % 500 == 0:
88
+ missing = vocab_set - global_seen_words
89
+ print(f"[INFO] After {doc_id+1} docs, {len(missing)} vocab words still not seen.")
90
+ print("Example missing words:", list(missing)[:5])
91
+
92
+ missing_final = vocab_set - global_seen_words
93
+ if missing_final:
94
+ print(f"[WARNING] {len(missing_final)} vocab words were never found in any document.")
95
+ print("Examples:", list(missing_final)[:10])
96
+
97
+ return index, docs, lemma_to_forms
98
+
99
+ def save_index_to_disk(index, index_path):
100
+ index_clean = {
101
+ word: {str(ts): doc_ids for ts, doc_ids in ts_dict.items()}
102
+ for word, ts_dict in index.items()
103
+ }
104
+ os.makedirs(os.path.dirname(index_path), exist_ok=True)
105
+ with open(index_path, "w", encoding='utf-8') as f:
106
+ json.dump(index_clean, f, ensure_ascii=False)
107
+
108
+ def load_index_from_disk(index_path):
109
+ with open(index_path, 'r', encoding='utf-8') as f:
110
+ raw_index = json.load(f)
111
+
112
+ index = defaultdict(lambda: defaultdict(list))
113
+ for word, ts_dict in raw_index.items():
114
+ for ts, doc_ids in ts_dict.items():
115
+ index[word][int(ts)] = doc_ids
116
+
117
+ return index
118
+
119
+ def load_docs(docs_file_path):
120
+ docs = []
121
+ with open(docs_file_path, 'r', encoding='utf-8') as f:
122
+ for line in f:
123
+ doc = json.loads(line)
124
+ docs.append({
125
+ "text": doc["text"],
126
+ "timestamp": int(doc["timestamp"])
127
+ })
128
+ return docs
129
+
130
+ def load_index(docs_file_path, vocab, index_path=None, lemma_map_path=None):
131
+ if index_path and os.path.exists(index_path):
132
+ index = load_index_from_disk(index_path)
133
+ docs = load_docs(docs_file_path)
134
+ lemma_to_forms = build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path)
135
+ return index, docs, lemma_to_forms
136
+
137
+ index, docs, lemma_to_forms = build_inverted_index(
138
+ docs_file_path,
139
+ set(vocab),
140
+ lemma_map_path=lemma_map_path
141
+ )
142
+
143
+ if index_path:
144
+ save_index_to_disk(index, index_path)
145
+
146
+ return index, docs, lemma_to_forms
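The typical entry point is load_index, which transparently builds or loads both the inverted index and the lemma map; the paths below are assumptions matching this repo's data layout:

    vocab = [line.strip() for line in open("data/ACL_Anthology/processed/vocab.txt", encoding="utf-8")]
    index, docs, lemma_to_forms = load_index(
        "data/ACL_Anthology/docs.jsonl",
        vocab,
        index_path="data/ACL_Anthology/inverted_index.json",
        lemma_map_path="data/ACL_Anthology/processed/lemma_to_forms.json",
    )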
backend/inference/peak_detector.py ADDED
@@ -0,0 +1,18 @@
1
+ import numpy as np
2
+ from scipy.signal import find_peaks
3
+
4
+ def detect_peaks(trend, prominence=0.001, distance=2):
5
+ """
6
+ Detect peaks in a word's trend over time.
7
+
8
+ Args:
9
+ trend: List or np.array of floats (word importance over time)
10
+ prominence: Required prominence of peaks (tune based on scale)
11
+ distance: Minimum distance between peaks
12
+
13
+ Returns:
14
+ List of indices (timestamps) where peaks occur
15
+ """
16
+ trend = np.array(trend)
17
+ peaks, _ = find_peaks(trend, prominence=prominence, distance=distance)
18
+ return peaks.tolist()
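A quick sanity check with invented values; the two bumps should be reported:

    trend = [0.01, 0.02, 0.08, 0.02, 0.01, 0.03, 0.09, 0.02]
    print(detect_peaks(trend, prominence=0.01, distance=2))  # [2, 6]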
backend/inference/process_beta.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import json
3
+
4
+ def load_beta_matrix(beta_path: str, vocab_path: str):
5
+ """
6
+ Loads the beta matrix (T x K x V) and vocab list.
7
+
8
+ Returns:
9
+ beta: np.ndarray of shape (T, K, V)
10
+ vocab: list of words
11
+ """
12
+ beta = np.load(beta_path) # shape: T x K x V
13
+ with open(vocab_path, 'r') as f:
14
+ vocab = [line.strip() for line in f.readlines()]
15
+ return beta, vocab
16
+
17
+ def get_top_words_at_time(beta, vocab, topic_id, time, top_n):
18
+ topic_beta = beta[time, topic_id, :]
19
+ top_indices = topic_beta.argsort()[-top_n:][::-1]
20
+ return [vocab[i] for i in top_indices]
21
+
22
+ def get_top_words_over_time(beta, vocab, topic_id, top_n):
23
+ topic_beta = beta[:, topic_id, :]
24
+ mean_beta = topic_beta.mean(axis=0)
25
+ top_indices = mean_beta.argsort()[-top_n:][::-1]
26
+ return [vocab[i] for i in top_indices]
27
+
28
+ def load_time_labels(time2id_path):
29
+ with open(time2id_path, 'r') as f:
30
+ time2id = json.load(f)
31
+ # Invert and sort by id
32
+ id2time = {v: k for k, v in time2id.items()}
33
+ return [id2time[i] for i in sorted(id2time)]
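Usage sketch; the beta and vocab paths are assumptions based on this repo's data directory:

    beta, vocab = load_beta_matrix(
        "data/ACL_Anthology/DETM/beta.npy",
        "data/ACL_Anthology/processed/vocab.txt",
    )
    print(get_top_words_at_time(beta, vocab, topic_id=0, time=0, top_n=10))
    print(get_top_words_over_time(beta, vocab, topic_id=0, top_n=10))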
backend/inference/word_selector.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ from scipy.special import softmax
3
+
4
+ def get_interesting_words(beta, vocab, topic_id, top_k_final=10, restrict_to=None):
5
+ """
6
+ Suggests interesting words by prioritizing "bursty" or "emerging" terms,
7
+ making it effective at capturing important low-probability words.
8
+
9
+ This algorithm focuses on the ratio of a word's peak probability to its mean,
10
+ capturing words that show significant growth or have a sudden moment of high
11
+ relevance, even if their average probability is low.
12
+
13
+ Parameters:
14
+ - beta: np.ndarray (T, K, V) - Topic-word distributions for each timestamp.
15
+ - vocab: list of V words - The vocabulary.
16
+ - topic_id: int - The ID of the topic to analyze.
17
+ - top_k_final: int - The number of words to return.
18
+ - restrict_to: optional list of str - Restricts scoring to a subset of words.
19
+
20
+ Returns:
21
+ - list of top_k_final interesting words (strings).
22
+ """
23
+ T, K, V = beta.shape
24
+
25
+ # --- 1. Detect whether softmax is needed ---
26
+ row_sums = beta.sum(axis=2)
27
+ is_prob_dist = np.allclose(row_sums, 1.0, atol=1e-2)
28
+
29
+ if not is_prob_dist:
30
+ print("🔁 Beta is not normalized — applying softmax across words per topic.")
31
+ beta = softmax(beta / 1e-3, axis=2)
32
+
33
+ # --- 2. Now extract normalized topic slice ---
34
+ topic_beta = beta[:, topic_id, :] # Shape: (T, V)
35
+
36
+ # Mean and Peak probability within the topic for each word
37
+ mean_topic = topic_beta.mean(axis=0) # Shape: (V,)
38
+ peak_topic = topic_beta.max(axis=0) # Shape: (V,)
39
+
40
+ # Corpus-wide mean for baseline comparison
41
+ mean_all = beta.mean(axis=(0, 1)) # Shape: (V,)
42
+
43
+ # Epsilon to prevent division by zero for words that never appear
44
+ epsilon = 1e-9
45
+
46
+ # --- 3. Calculate the three core components of the new score ---
47
+
48
+ # a) Burstiness Score: How much a word's peak stands out from its own average.
49
+ # This is the key to finding "surprising" words.
50
+ burstiness_score = peak_topic / (mean_topic + epsilon)
51
+
52
+ # b) Peak Specificity: How much the word's peak in this topic stands out from
53
+ # its average presence in the entire corpus.
54
+ peak_specificity_score = peak_topic / (mean_all + epsilon)
55
+
56
+ # c) Uniqueness Score (same as before): Penalizes words active in many topics.
57
+ active_in_topics = (beta > 1e-5).mean(axis=0) # Shape: (K, V)
58
+ idf_like = np.log((K + 1) / (active_in_topics.sum(axis=0) + 1)) # Shape: (V,)
59
+
60
+ # --- 4. Compute Final Interestingness Score ---
61
+ # This score is high for words that are unique, have a high peak relative
62
+ # to their baseline, and whose peak is an unusual event for that word.
63
+ final_scores = burstiness_score * peak_specificity_score * idf_like
64
+
65
+ # --- 5. Rank and select top words ---
66
+ if restrict_to is not None:
67
+ restrict_set = set(restrict_to)
68
+ word_indices = [i for i, w in enumerate(vocab) if w in restrict_set]
69
+ else:
70
+ word_indices = np.arange(V)
71
+
72
+ if len(word_indices) == 0:  # `not` on a NumPy array raises ValueError
73
+ return []
74
+
75
+ # Rank the filtered indices by the final score in descending order
76
+ sorted_indices = sorted(word_indices, key=lambda i: -final_scores[i])
77
+
78
+ return [vocab[i] for i in sorted_indices[:top_k_final]]
79
+
80
+
81
+ def get_word_trend(beta, vocab, word, topic_id):
82
+ """
83
+ Get the time trend of a word's probability under a specific topic.
84
+
85
+ Args:
86
+ beta: np.ndarray of shape (T, K, V)
87
+ vocab: list of vocab words
88
+ word: word to search
89
+ topic_id: index of topic to inspect (0 <= topic_id < K)
90
+
91
+ Returns:
92
+ List of word probabilities over time (length T)
93
+ """
94
+ T, K, V = beta.shape
95
+ if word not in vocab:
96
+ raise ValueError(f"Word '{word}' not found in vocab.")
97
+ if not (0 <= topic_id < K):
98
+ raise ValueError(f"Invalid topic_id {topic_id}. Must be between 0 and {K - 1}.")
99
+
100
+ word_index = vocab.index(word)
101
+ trend = beta[:, topic_id, word_index] # shape (T,)
102
+ return trend.tolist()
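Usage sketch, assuming beta and vocab were loaded via process_beta above:

    words = get_interesting_words(beta, vocab, topic_id=3, top_k_final=10)
    trend = get_word_trend(beta, vocab, words[0], topic_id=3)  # length-T list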
backend/llm/custom_gemini.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Any, List, Optional
+ 
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+ from langchain_core.outputs import ChatGeneration, ChatResult
+ 
+ 
+ class ChatGemini(BaseChatModel):
+     # BaseChatModel is a pydantic model, so fields must be declared up front;
+     # assigning undeclared attributes inside __init__ raises an error.
+     model: str = "gemini-pro"
+     temperature: float = 0.7
+     api_key: str
+     client: Any = None
+ 
+     def __init__(self, api_key: str, model: str = "gemini-pro", temperature: float = 0.7, **kwargs):
+         super().__init__(api_key=api_key, model=model, temperature=temperature, **kwargs)
+         self.client = ChatGoogleGenerativeAI(
+             model=model,
+             temperature=temperature,
+             google_api_key=api_key
+         )
+ 
+     def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs) -> ChatResult:
+         # Flatten the LangChain message history into a single prompt string.
+         prompt = "\n".join(
+             msg.content for msg in messages if isinstance(msg, (HumanMessage, AIMessage))
+         )
+         response = self.client.invoke(prompt)  # returns an AIMessage
+         # BaseChatModel expects a ChatResult, not a bare message.
+         return ChatResult(generations=[ChatGeneration(message=response)])
+ 
+     @property
+     def _llm_type(self) -> str:
+         return "gemini"
backend/llm/custom_mistral.py ADDED
@@ -0,0 +1,27 @@
+ from typing import Optional
+ 
+ from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_core.messages import HumanMessage, AIMessage
+ from langchain_core.outputs import ChatResult, ChatGeneration
+ import requests
+ import os
+ 
+ 
+ class ChatMistral(BaseChatModel):
+     # Declare pydantic fields up front; BaseChatModel forbids ad-hoc attributes.
+     hf_token: Optional[str] = None
+     model_url: str = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
+ 
+     def __init__(self, hf_token=None, model_url=None, **kwargs):
+         super().__init__(
+             hf_token=hf_token or os.getenv("HF_TOKEN"),
+             model_url=model_url or "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1",
+             **kwargs,
+         )
+ 
+     @property
+     def headers(self):
+         return {"Authorization": f"Bearer {self.hf_token}"}
+ 
+     def _call(self, prompt: str) -> str:
+         response = requests.post(
+             self.model_url,
+             headers=self.headers,
+             json={"inputs": prompt, "parameters": {"max_new_tokens": 256}},
+         )
+         response.raise_for_status()
+         return response.json()[0]["generated_text"]
+ 
+     def invoke(self, messages, **kwargs):
+         prompt = "\n".join(msg.content for msg in messages if isinstance(msg, HumanMessage))
+         return AIMessage(content=self._call(prompt))
+ 
+     def _generate(self, messages, stop=None, **kwargs) -> ChatResult:
+         return ChatResult(generations=[ChatGeneration(message=self.invoke(messages))])
+ 
+     @property
+     def _llm_type(self) -> str:
+         # Required abstract property; without it the class cannot be instantiated.
+         return "mistral-hf-inference"
backend/llm/llm_router.py ADDED
@@ -0,0 +1,73 @@
+ from langchain_anthropic import ChatAnthropic
+ from backend.llm.custom_mistral import ChatMistral
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_openai import ChatOpenAI
+ import os
+ import requests
+ 
+ # Best-effort connectivity probe through any configured proxy; a failure
+ # here should not prevent the module from importing.
+ try:
+     requests.get("https://www.google.com", timeout=5, proxies={
+         "http": os.getenv("http_proxy"),
+         "https": os.getenv("https_proxy")
+     })
+ except requests.RequestException:
+     pass
+ 
+ def list_supported_models(provider=None):
+     if provider == "OpenAI":
+         return ["gpt-4.1-nano", "gpt-4o-mini"]
+     elif provider == "Anthropic":
+         return ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
+     elif provider == "Gemini":
+         return ["gemini-2.0-flash-lite", "gemini-1.5-flash"]
+     elif provider == "Mistral":
+         return ["mistral-small", "mistral-medium"]
+     else:
+         # Default fallback: all models grouped by provider
+         return {
+             "OpenAI": ["gpt-4.1-nano", "gpt-4o-mini"],
+             "Anthropic": ["claude-3-opus-20240229", "claude-3-sonnet-20240229"],
+             "Gemini": ["gemini-2.0-flash-lite", "gemini-1.5-flash"],
+             "Mistral": ["mistral-small", "mistral-medium"]
+         }
+ 
+ 
+ def get_llm(provider: str, model: str, api_key: str = None):
+     if provider == "OpenAI":
+         api_key = api_key or os.getenv("OPENAI_API_KEY")
+         if not api_key:
+             raise ValueError("Missing OpenAI API key.")
+         return ChatOpenAI(model_name=model, temperature=0, openai_api_key=api_key)
+ 
+     elif provider == "Anthropic":
+         api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
+         if not api_key:
+             raise ValueError("Missing Anthropic API key.")
+         return ChatAnthropic(model=model, temperature=0, anthropic_api_key=api_key)
+ 
+     elif provider == "Gemini":
+         api_key = api_key or os.getenv("GEMINI_API_KEY")
+         if not api_key:
+             raise ValueError("Missing Gemini API key.")
+         # The Google client honors the standard proxy environment variables,
+         # so normalize the lowercase variants here instead of monkeypatching
+         # google.auth's request class (requests.Request takes no proxies kwarg).
+         if os.getenv("http_proxy") and not os.getenv("HTTP_PROXY"):
+             os.environ["HTTP_PROXY"] = os.environ["http_proxy"]
+         if os.getenv("https_proxy") and not os.getenv("HTTPS_PROXY"):
+             os.environ["HTTPS_PROXY"] = os.environ["https_proxy"]
+         return ChatGoogleGenerativeAI(model=model, temperature=0, google_api_key=api_key)
+ 
+     elif provider == "Mistral":
+         api_key = api_key or os.getenv("MISTRAL_API_KEY")
+         if not api_key:
+             raise ValueError("Missing Mistral API key.")
+         # ChatMistral wraps the HF Inference API and only accepts an HF token
+         # and model URL, so route the key accordingly.
+         return ChatMistral(hf_token=api_key)
+ 
+     else:
+         raise ValueError(f"Unsupported provider: {provider}")
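Usage sketch; this assumes the matching API key is present in the environment:

    llm = get_llm("OpenAI", "gpt-4o-mini")  # falls back to OPENAI_API_KEY
    print(llm.invoke("Say hello in five words.").content)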
backend/llm_utils/label_generator.py ADDED
@@ -0,0 +1,72 @@
1
+ from hashlib import sha256
2
+ import json
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from typing import Optional
6
+ import os
7
+
8
+ #get_top_words_at_time
9
+ from backend.inference.process_beta import get_top_words_at_time
10
+
11
+ def label_topic_temporal(word_trajectory_str: str, llm, cache_path: Optional[str] = None) -> str:
12
+ """
13
+ Label a dynamic topic by providing the LLM with the top words over time.
14
+
15
+ Args:
16
+ word_trajectory_str (str): Formatted keyword evolution string.
17
+ llm: LangChain-compatible LLM instance.
18
+ cache_path (Optional[str]): Path to the cache file (JSON).
19
+
20
+ Returns:
21
+ str: Short label for the topic.
22
+ """
23
+ topic_key = sha256(word_trajectory_str.encode()).hexdigest()
24
+
25
+ # Load cache
26
+ if cache_path is not None and os.path.exists(cache_path):
27
+ with open(cache_path, "r") as f:
28
+ label_cache = json.load(f)
29
+ else:
30
+ label_cache = {}
31
+
32
+ # Return cached result
33
+ if topic_key in label_cache:
34
+ return label_cache[topic_key]
35
+
36
+ # Prompt template
37
+ prompt = ChatPromptTemplate.from_template(
38
+ "You are an expert in topic modeling and temporal data analysis. "
39
+ "Given the top words for a topic across multiple time points, your task is to return a short, specific, descriptive topic label. "
40
+ "Avoid vague, generic, or overly broad labels. Focus on consistent themes in the top words over time. "
41
+ "Use concise noun phrases, 2–5 words max. Do NOT include any explanation, justification, or extra output.\n\n"
42
+ "Top words over time:\n{trajectory}\n\n"
43
+ "Return ONLY the label (no quotes, no extra text):"
44
+ )
45
+ chain = prompt | llm | StrOutputParser()
46
+
47
+ try:
48
+ label = chain.invoke({"trajectory": word_trajectory_str}).strip()
49
+ except Exception as e:
50
+ label = "Unknown Topic"
51
+ print(f"[Labeling Error] {e}")
52
+
53
+ # Update cache and save
54
+ label_cache[topic_key] = label
55
+ if cache_path is not None:
56
+ os.makedirs(os.path.dirname(cache_path), exist_ok=True)
57
+ with open(cache_path, "w") as f:
58
+ json.dump(label_cache, f, indent=2)
59
+
60
+ return label
61
+
62
+
63
+ def get_topic_labels(beta, vocab, time_labels, llm, cache_path):
64
+ topic_labels = {}
65
+ for topic_id in range(beta.shape[1]):
66
+ word_trajectory_str = "\n".join([
67
+ f"{time_labels[t]}: {', '.join(get_top_words_at_time(beta, vocab, topic_id, t, top_n=10))}"
68
+ for t in range(beta.shape[0])
69
+ ])
70
+ label = label_topic_temporal(word_trajectory_str, llm=llm, cache_path=cache_path)
71
+ topic_labels[topic_id] = label
72
+ return topic_labels
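Usage sketch; beta, vocab, and time_labels come from process_beta, llm from llm_router, and the cache path is an assumption matching this repo's layout:

    labels = get_topic_labels(
        beta, vocab, time_labels, llm,
        cache_path="data/ACL_Anthology/DTM/topic_label_cache.json",
    )
    print(labels[0])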
backend/llm_utils/summarizer.py ADDED
@@ -0,0 +1,192 @@
1
+ import hashlib
2
+ import numpy as np
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ from sentence_transformers import SentenceTransformer
5
+ import faiss
6
+
7
+ from langchain.prompts import ChatPromptTemplate
8
+ from langchain.docstore.document import Document
9
+ from langchain.memory import ConversationBufferMemory
10
+ from langchain.chains import ConversationChain
11
+
12
+ import os
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+
16
+ # --- MMR Utilities ---
17
+ def build_mmr_index(docs):
18
+ texts = [doc['text'] for doc in docs if 'text' in doc]
19
+ documents = [Document(page_content=text) for text in texts]
20
+
21
+ model = SentenceTransformer("all-MiniLM-L6-v2")
22
+ embeddings = model.encode([doc.page_content for doc in documents], convert_to_numpy=True)
23
+ faiss.normalize_L2(embeddings)
24
+
25
+ index = faiss.IndexFlatIP(embeddings.shape[1])
26
+ index.add(embeddings)
27
+
28
+ return model, index, embeddings, documents
29
+
30
+ def get_mmr_sample(model, index, embeddings, documents, query, k=15, lambda_mult=0.7):
31
+ if len(documents) == 0:
32
+ print("Warning: No documents available, returning empty list.")
33
+ return []
34
+
35
+ if len(documents) <= k:
36
+ print(f"Warning: Only {len(documents)} documents available, returning all.")
37
+ return documents
38
+
39
+ else:
40
+ query_vec = model.encode(query, convert_to_numpy=True)
41
+ query_vec = query_vec / np.linalg.norm(query_vec)
42
+
43
+ # Get candidate indices from FAISS (k * 4 or less if not enough documents)
44
+ num_candidates = min(k * 4, len(documents))
45
+ D, I = index.search(np.expand_dims(query_vec, axis=0), num_candidates)
46
+ candidate_idxs = list(I[0])
47
+
48
+ selected = []
49
+ while len(selected) < k and candidate_idxs:
50
+ if not selected:
51
+ selected.append(candidate_idxs.pop(0))
52
+ continue
53
+
54
+ mmr_scores = []
55
+ for idx in candidate_idxs:
56
+ relevance = cosine_similarity([query_vec], [embeddings[idx]])[0][0]
57
+ diversity = max([
58
+ cosine_similarity([embeddings[idx]], [embeddings[sel]])[0][0]
59
+ for sel in selected
60
+ ])
61
+ mmr_score = lambda_mult * relevance - (1 - lambda_mult) * diversity
62
+ mmr_scores.append((idx, mmr_score))
63
+
64
+ next_best = max(mmr_scores, key=lambda x: x[1])[0]
65
+ selected.append(next_best)
66
+ candidate_idxs.remove(next_best)
67
+
68
+ return [documents[i] for i in selected]
69
+
70
+
71
+ # --- Summarization ---
72
+ def summarize_docs(word, timestamp, docs, llm, k):
73
+ if not docs:
74
+ return "No documents available for this word at this time.", [], 0
75
+
76
+ try:
77
+ model, index, embeddings, documents = build_mmr_index(docs)
78
+ mmr_docs = get_mmr_sample(model, index, embeddings, documents, query=word, k=k)
79
+
80
+ context_texts = "\n".join(f"- {doc.page_content}" for doc in mmr_docs)
81
+
82
+ prompt_template = ChatPromptTemplate.from_template(
83
+ "Given the following documents from {timestamp} containing the word '{word}', "
84
+ "identify the key themes or distinct discussion points that were prevalent during that time. "
85
+ "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
86
+ "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
87
+ "Return no more than 5–7 bullets.\n\n"
88
+ "{context_texts}\n\nSummary:"
89
+ )
90
+
91
+ chain = prompt_template | llm
92
+ summary = chain.invoke({
93
+ "word": word,
94
+ "timestamp": timestamp,
95
+ "context_texts": context_texts
96
+ }).content.strip()
97
+
98
+ return summary, mmr_docs
99
+
100
+ except Exception as e:
101
+ return f"[Error summarizing: {e}]", [], 0
102
+
103
+
104
+ def summarize_multiword_docs(words, timestamp, docs, llm, k):
105
+ if not docs:
106
+ return "No common documents available for these words at this time.", []
107
+
108
+ try:
109
+ model, index, embeddings, documents = build_mmr_index(docs)
110
+ query = " ".join(words)
111
+ mmr_docs = get_mmr_sample(model, index, embeddings, documents, query=query, k=k)
112
+
113
+ context_texts = "\n".join(f"- {doc.page_content}" for doc in mmr_docs)
114
+
115
+ prompt_template = ChatPromptTemplate.from_template(
116
+ "Given the following documents from {timestamp} that all mention the words: '{word_list}', "
117
+ "identify the key themes or distinct discussion points that were prevalent during that time. "
118
+ "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
119
+ "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
120
+ "Return no more than 5–7 bullets.\n\n"
121
+ "{context_texts}\n\n"
122
+ "Concise Thematic Summary:"
123
+ )
124
+
125
+ chain = prompt_template | llm
126
+ summary = chain.invoke({
127
+ "word_list": ", ".join(words),
128
+ "timestamp": timestamp,
129
+ "context_texts": context_texts
130
+ }).content.strip()
131
+
132
+ return summary, mmr_docs
133
+
134
+ except Exception as e:
135
+ return f"[Error summarizing: {e}]", []
136
+
137
+
138
+ # --- Follow-up Question Handler (Improved) ---
139
+ def ask_multiturn_followup(history: list, question: str, llm, context_texts: str) -> str:
140
+ """
141
+ Handles multi-turn follow-up questions based on a provided set of documents.
142
+
143
+ This function now REQUIRES context_texts to be provided, ensuring the LLM
144
+ is always grounded in the source documents for follow-up questions.
145
+
146
+ Args:
147
+ history (list): A list of dictionaries representing the conversation history
148
+ (e.g., [{"role": "user", "content": "..."}]).
149
+ question (str): The user's new follow-up question.
150
+ llm: The initialized language model instance.
151
+ context_texts (str): A single string containing all the numbered documents
152
+ for context.
153
+
154
+ Returns:
155
+ str: The AI's response to the follow-up question.
156
+ """
157
+ try:
158
+ # 1. Reconstruct conversation memory from the history provided from the UI
159
+ memory = ConversationBufferMemory(return_messages=True)
160
+ for turn in history:
161
+ if turn["role"] == "user":
162
+ memory.chat_memory.add_user_message(turn["content"])
163
+ elif turn["role"] == "assistant":
164
+ memory.chat_memory.add_ai_message(turn["content"])
165
+
166
+ # 2. Define the system instruction that grounds the LLM
167
+ system_instruction = (
168
+ "You are an assistant answering questions strictly based on the provided sample documents below. "
169
+ "Your memory contains the previous turns of this conversation. "
170
+ "If the answer is not clearly available in the text, respond with: "
171
+ "'The information is not available in the documents provided.'\n\n"
172
+ )
173
+
174
+ # 3. Create the full prompt. No more conditional logic, as context is required.
175
+ # The `ConversationChain` will automatically use the memory, so we only need
176
+ # to provide the current input, which includes the grounding documents.
177
+ full_prompt = (
178
+ f"{system_instruction}"
179
+ f"--- DOCUMENTS ---\n{context_texts.strip()}\n\n"
180
+ f"--- QUESTION ---\n{question}"
181
+ )
182
+
183
+ # 4. Create and run the conversation chain
184
+ conversation = ConversationChain(llm=llm, memory=memory, verbose=False)
185
+ response = conversation.predict(input=full_prompt)
186
+
187
+ return response.strip()
188
+
189
+ except Exception as e:
190
+ # Good practice to log the full exception for easier debugging
191
+ print(f"[ERROR] in ask_multiturn_followup: {e}")
192
+ return f"[Error during multi-turn follow-up. Please check the logs.]"
backend/llm_utils/token_utils.py ADDED
@@ -0,0 +1,167 @@
1
+ from typing import Literal
2
+ import tiktoken
3
+ import anthropic
4
+ from typing import List
5
+
6
+ # Gemini requires the Vertex AI SDK
7
+ try:
8
+ from vertexai.preview import tokenization as vertex_tokenization
9
+ except ImportError:
10
+ vertex_tokenization = None
11
+
12
+ # Mistral requires the SentencePiece tokenizer
13
+ try:
14
+ import sentencepiece as spm
15
+ except ImportError:
16
+ spm = None
17
+
18
+ # ---------------------------
19
+ # Individual Token Counters
20
+ # ---------------------------
21
+
22
+ def count_tokens_openai(text: str, model_name: str) -> int:
23
+ try:
24
+ encoding = tiktoken.encoding_for_model(model_name)
25
+ except KeyError:
26
+ encoding = tiktoken.get_encoding("cl100k_base") # fallback
27
+ return len(encoding.encode(text))
28
+
29
+ def count_tokens_anthropic(text: str, model_name: str) -> int:
30
+ try:
31
+ client = anthropic.Anthropic()
32
+ response = client.messages.count_tokens(
33
+ model=model_name,
34
+ messages=[{"role": "user", "content": text}]
35
+ )
36
+ return response.input_tokens  # the SDK returns a usage object, not a dict
37
+ except Exception as e:
38
+ raise RuntimeError(f"Anthropic token counting failed: {e}")
39
+
40
+ def count_tokens_gemini(text: str, model_name: str) -> int:
41
+ if vertex_tokenization is None:
42
+ raise ImportError("Please install vertexai: pip install google-cloud-aiplatform[tokenization]")
43
+ try:
44
+ # The local Vertex tokenizer ships only for a few models, so a known
+ # supported one is used as an approximation regardless of model_name.
+ tokenizer = vertex_tokenization.get_tokenizer_for_model("gemini-1.5-flash-002")
45
+ result = tokenizer.count_tokens(text)
46
+ return result.total_tokens
47
+ except Exception as e:
48
+ raise RuntimeError(f"Gemini token counting failed: {e}")
49
+
50
+ def count_tokens_mistral(text: str) -> int:
51
+ if spm is None:
52
+ raise ImportError("Please install sentencepiece: pip install sentencepiece")
53
+ try:
54
+ sp = spm.SentencePieceProcessor()
55
+ # IMPORTANT: You must provide the correct path to the tokenizer model file
56
+ sp.load("mistral_tokenizer.model")
57
+ tokens = sp.encode(text, out_type=str)
58
+ return len(tokens)
59
+ except Exception as e:
60
+ raise RuntimeError(f"Mistral token counting failed: {e}")
61
+
62
+ # ---------------------------
63
+ # Unified Token Counter
64
+ # ---------------------------
65
+
66
+ def count_tokens(text: str, model_name: str, provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]) -> int:
67
+ if provider == "OpenAI":
68
+ return count_tokens_openai(text, model_name)
69
+ elif provider == "Anthropic":
70
+ return count_tokens_anthropic(text, model_name)
71
+ elif provider == "Gemini":
72
+ return count_tokens_gemini(text, model_name)
73
+ elif provider == "Mistral":
74
+ return count_tokens_mistral(text)
75
+ else:
76
+ raise ValueError(f"Unsupported provider: {provider}")
77
+
78
+
79
+ def get_token_limit_for_model(model_name, provider):
80
+ # Example values; update as needed for your providers.
+ # Callers pass capitalized provider names (e.g. "OpenAI"), so normalize first.
+ provider = (provider or "").lower()
+ if provider == "openai":
82
+ if "gpt-4.1-nano" in model_name:
83
+ return 1047576 # Based on search results
84
+ elif "gpt-4o-mini" in model_name:
85
+ return 128000 # Based on search results
86
+ elif provider == "anthropic":
87
+ if "claude-3-opus" in model_name:
88
+ return 200000 # Based on search results
89
+ elif "claude-3-sonnet" in model_name:
90
+ return 200000 # Based on search results
91
+ elif provider == "gemini":
92
+ if "gemini-2.0-flash-lite" in model_name:
93
+ return 1048576 # Based on search results
94
+ elif "gemini-1.5-flash" in model_name:
95
+ return 1048576 # Based on search results
96
+ elif provider == "mistral":
97
+ if "mistral-small" in model_name:
98
+ return 32000 # Based on search results
99
+ elif "mistral-medium" in model_name:
100
+ return 32000 # Based on search results
101
+ return 8000 # default fallback
102
+
103
+
104
+ def estimate_avg_tokens_per_doc(
105
+ docs: List[str],
106
+ model_name: str,
107
+ provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]
108
+ ) -> float:
109
+ """
110
+ Estimate the average number of tokens per document for the given model.
111
+
112
+ Args:
113
+ docs (List[str]): List of documents.
114
+ model_name (str): Model name.
115
+ provider (Literal): LLM provider.
116
+
117
+ Returns:
118
+ float: Average number of tokens per document.
119
+ """
120
+ if not docs:
121
+ return 0.0
122
+ token_counts = [count_tokens(doc, model_name, provider) for doc in docs]
123
+ return sum(token_counts) / len(token_counts)
124
+
125
+ def estimate_max_k(
126
+ docs: List[str],
127
+ model_name: str,
128
+ provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
129
+ margin_ratio: float = 0.1,
130
+ ) -> int:
131
+ """
132
+ Estimate the maximum number of documents that can fit in the context window.
133
+
134
+ Returns:
135
+ int: Estimated K.
136
+ """
137
+ if not docs:
138
+ return 0
139
+
140
+ max_tokens = get_token_limit_for_model(model_name, provider)
141
+ margin = int(max_tokens * margin_ratio)
142
+ available_tokens = max_tokens - margin
143
+
144
+ avg_tokens_per_doc = estimate_avg_tokens_per_doc(docs, model_name, provider)
145
+ if avg_tokens_per_doc == 0:
146
+ return 0
147
+
148
+ return min(len(docs), int(available_tokens // avg_tokens_per_doc))
149
+
150
+ def estimate_max_k_fast(docs, margin_ratio=0.1, max_tokens=8000, model_name="gpt-3.5-turbo"):
151
+ enc = tiktoken.encoding_for_model(model_name)
152
+ avg_len = sum(len(enc.encode(doc)) for doc in docs[:20]) / min(20, len(docs))
153
+ margin = int(max_tokens * margin_ratio)
154
+ available = max_tokens - margin
155
+ return min(len(docs), int(available // avg_len))
156
+
157
+ def estimate_k_max_from_word_stats(
158
+ avg_words_per_doc: float,
159
+ margin_ratio: float = 0.1,
160
+ avg_tokens_per_word: float = 1.3,
161
+ model_name=None,
162
+ provider=None
163
+ ) -> int:
164
+ model_token_limit = get_token_limit_for_model(model_name, provider)
165
+ effective_limit = int(model_token_limit * (1 - margin_ratio))
166
+ est_tokens_per_doc = avg_words_per_doc * avg_tokens_per_word
167
+ return int(effective_limit // est_tokens_per_doc)
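Usage sketch for the budgeting helpers; doc_texts is any list of raw document strings:

    k_max = estimate_max_k(doc_texts, "gpt-4o-mini", "OpenAI", margin_ratio=0.1)
    print(f"Can pack roughly {k_max} documents into the context window.")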
backend/models/CFDTM/CFDTM.py ADDED
@@ -0,0 +1,127 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .ETC import ETC
7
+ from .UWE import UWE
8
+ from .Encoder import MLPEncoder
9
+
10
+
11
+ class CFDTM(nn.Module):
12
+ '''
13
+ Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings
14
+
15
+ Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.
16
+ '''
17
+
18
+ def __init__(self,
19
+ vocab_size,
20
+ train_time_wordfreq,
21
+ num_times,
22
+ pretrained_WE=None,
23
+ num_topics=50,
24
+ en_units=100,
25
+ temperature=0.1,
26
+ beta_temp=1.0,
27
+ weight_neg=1.0e+7,
28
+ weight_pos=1.0e+1,
29
+ weight_UWE=1.0e+3,
30
+ neg_topk=15,
31
+ dropout=0.,
32
+ embed_size=200
33
+ ):
34
+ super().__init__()
35
+
36
+ self.num_topics = num_topics
37
+ self.beta_temp = beta_temp
38
+ self.train_time_wordfreq = train_time_wordfreq
39
+ self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)
40
+
41
+ self.a = 1 * np.ones((1, num_topics)).astype(np.float32)
42
+ self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T))
43
+ self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T))
44
+
45
+ self.mu2.requires_grad = False
46
+ self.var2.requires_grad = False
47
+
48
+ self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)
49
+
50
+ if pretrained_WE is None:
51
+ self.word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
52
+ self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
53
+
54
+ else:
55
+ self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())
56
+
57
+ # topic_embeddings: TxKxD
58
+ self.topic_embeddings = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1])).repeat(num_times, 1, 1)
59
+ self.topic_embeddings = nn.Parameter(self.topic_embeddings)
60
+
61
+ self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
62
+ self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)
63
+
64
+ def get_beta(self):
65
+ dist = self.pairwise_euclidean_dist(F.normalize(self.topic_embeddings, dim=-1), F.normalize(self.word_embeddings, dim=-1))
66
+ beta = F.softmax(-dist / self.beta_temp, dim=1)
67
+
68
+ return beta
69
+
70
+ def pairwise_euclidean_dist(self, x, y):
71
+ cost = torch.sum(x ** 2, axis=-1, keepdim=True) + torch.sum(y ** 2, axis=-1) - 2 * torch.matmul(x, y.t())
72
+ return cost
73
+
74
+ def get_theta(self, x, times=None):
75
+ theta, mu, logvar = self.encoder(x)
76
+ if self.training:
77
+ return theta, mu, logvar
78
+
79
+ return theta
80
+
81
+ def get_KL(self, mu, logvar):
82
+ var = logvar.exp()
83
+ var_division = var / self.var2
84
+ diff = mu - self.mu2
85
+ diff_term = diff * diff / self.var2
86
+ logvar_division = self.var2.log() - logvar
87
+ KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(axis=1) - self.num_topics)
88
+
89
+ return KLD.mean()
90
+
91
+ def get_NLL(self, theta, beta, x, recon_x=None):
92
+ if recon_x is None:
93
+ recon_x = self.decode(theta, beta)
94
+ recon_loss = -(x * recon_x.log()).sum(axis=1)
95
+
96
+ return recon_loss
97
+
98
+ def decode(self, theta, beta):
99
+ d1 = F.softmax(self.decoder_bn(torch.bmm(theta.unsqueeze(1), beta).squeeze(1)), dim=-1)
100
+ return d1
101
+
102
+ def forward(self, x, times):
103
+ loss = 0.
104
+
105
+ theta, mu, logvar = self.get_theta(x)
106
+ kl_theta = self.get_KL(mu, logvar)
107
+
108
+ loss += kl_theta
109
+
110
+ beta = self.get_beta()
111
+ time_index_beta = beta[times]
112
+ recon_x = self.decode(theta, time_index_beta)
113
+ NLL = self.get_NLL(theta, time_index_beta, x, recon_x)
114
+ NLL = NLL.mean()
115
+ loss += NLL
116
+
117
+ loss_ETC = self.ETC(self.topic_embeddings)
118
+ loss += loss_ETC
119
+
120
+ loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)
121
+ loss += loss_UWE
122
+
123
+ rst_dict = {
124
+ 'loss': loss,
125
+ }
126
+
127
+ return rst_dict
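A shape-only smoke test with random data (real training lives in the dynamic trainer); all sizes here are invented:

    import torch
    T, V, K = 5, 1000, 20
    freq = torch.randint(0, 5, (T, V)).float()      # per-time word counts
    model = CFDTM(vocab_size=V, train_time_wordfreq=freq, num_times=T, num_topics=K)
    x = torch.rand(8, V)                            # batch of BoW vectors
    times = torch.randint(0, T, (8,))
    print(model(x, times)["loss"])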
backend/models/CFDTM/ETC.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ETC(nn.Module):
7
+ def __init__(self, num_times, temperature, weight_neg, weight_pos):
8
+ super().__init__()
9
+ self.num_times = num_times
10
+ self.weight_neg = weight_neg
11
+ self.weight_pos = weight_pos
12
+ self.temperature = temperature
13
+
14
+ def forward(self, topic_embeddings):
15
+ loss = 0.
16
+ loss_neg = 0.
17
+ loss_pos = 0.
18
+
19
+ for t in range(self.num_times):
20
+ loss_neg += self.compute_loss(topic_embeddings[t], topic_embeddings[t], self.temperature, self_contrast=True)
21
+
22
+ for t in range(1, self.num_times):
23
+ loss_pos += self.compute_loss(topic_embeddings[t], topic_embeddings[t - 1].detach(), self.temperature, self_contrast=False, only_pos=True)
24
+
25
+ loss_neg *= (self.weight_neg / self.num_times)
26
+ loss_pos *= (self.weight_pos / (self.num_times - 1))
27
+ loss = loss_neg + loss_pos
28
+
29
+ return loss
30
+
31
+ def compute_loss(self, anchor_feature, contrast_feature, temperature, self_contrast=False, only_pos=False, all_neg=False):
32
+ # KxK
33
+ anchor_dot_contrast = torch.div(
34
+ torch.matmul(F.normalize(anchor_feature, dim=1), F.normalize(contrast_feature, dim=1).T),
35
+ temperature
36
+ )
37
+
38
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
39
+ logits = anchor_dot_contrast - logits_max.detach()
40
+
41
+ pos_mask = torch.eye(anchor_dot_contrast.shape[0]).to(anchor_dot_contrast.device)
42
+
43
+ if self_contrast is False:
44
+ if only_pos is False:
45
+ if all_neg is True:
46
+ exp_logits = torch.exp(logits)
47
+ sum_exp_logits = exp_logits.sum(1)
48
+ log_prob = -torch.log(sum_exp_logits + 1e-12)
49
+
50
+ mean_log_prob = -log_prob.sum() / (logits.shape[0] * logits.shape[1])
51
+ else:
52
+ # only pos
53
+ mean_log_prob = -(logits * pos_mask).sum() / pos_mask.sum()
54
+ else:
55
+ # self contrast: push away from each other in the same time slice.
56
+ exp_logits = torch.exp(logits) * (1 - pos_mask)
57
+ sum_exp_logits = exp_logits.sum(1)
58
+ log_prob = -torch.log(sum_exp_logits + 1e-12)
59
+
60
+ mean_log_prob = -log_prob.sum() / (1 - pos_mask).sum()
61
+
62
+ return mean_log_prob
backend/models/CFDTM/Encoder.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class MLPEncoder(nn.Module):
7
+ def __init__(self, vocab_size, num_topic, hidden_dim, dropout):
8
+ super().__init__()
9
+
10
+ self.fc11 = nn.Linear(vocab_size, hidden_dim)
11
+ self.fc12 = nn.Linear(hidden_dim, hidden_dim)
12
+ self.fc21 = nn.Linear(hidden_dim, num_topic)
13
+ self.fc22 = nn.Linear(hidden_dim, num_topic)
14
+
15
+ self.fc1_drop = nn.Dropout(dropout)
16
+ self.z_drop = nn.Dropout(dropout)
17
+
18
+ self.mean_bn = nn.BatchNorm1d(num_topic, affine=True)
19
+ self.mean_bn.weight.requires_grad = False
20
+ self.logvar_bn = nn.BatchNorm1d(num_topic, affine=True)
21
+ self.logvar_bn.weight.requires_grad = False
22
+
23
+ def reparameterize(self, mu, logvar):
24
+ if self.training:
25
+ std = torch.exp(0.5 * logvar)
26
+ eps = torch.randn_like(std)
27
+ return mu + (eps * std)
28
+ else:
29
+ return mu
30
+
31
+ def forward(self, x):
32
+ e1 = F.softplus(self.fc11(x))
33
+ e1 = F.softplus(self.fc12(e1))
34
+ e1 = self.fc1_drop(e1)
35
+ mu = self.mean_bn(self.fc21(e1))
36
+ logvar = self.logvar_bn(self.fc22(e1))
37
+ theta = self.reparameterize(mu, logvar)
38
+ theta = F.softmax(theta, dim=1)
39
+ theta = self.z_drop(theta)
40
+ return theta, mu, logvar
backend/models/CFDTM/UWE.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class UWE(nn.Module):
6
+ def __init__(self, ETC, num_times, temperature, weight_UWE, neg_topk):
7
+ super().__init__()
8
+
9
+ self.ETC = ETC
10
+ self.weight_UWE = weight_UWE
11
+ self.num_times = num_times
12
+ self.temperature = temperature
13
+ self.neg_topk = neg_topk
14
+
15
+ def forward(self, time_wordcount, beta, topic_embeddings, word_embeddings):
16
+ assert(self.num_times == time_wordcount.shape[0])
17
+
18
+ topk_indices = self.get_topk_indices(beta)
19
+
20
+ loss_UWE = 0.
21
+ cnt_valid_times = 0.
22
+ for t in range(self.num_times):
23
+ neg_idx = torch.where(time_wordcount[t] == 0)[0]
24
+
25
+ time_topk_indices = topk_indices[t]
26
+ neg_idx = list(set(neg_idx.cpu().tolist()).intersection(set(time_topk_indices.cpu().tolist())))
27
+ neg_idx = torch.tensor(neg_idx).long().to(time_wordcount.device)
28
+
29
+ if len(neg_idx) == 0:
30
+ continue
31
+
32
+ time_neg_WE = word_embeddings[neg_idx]
33
+
34
+ # topic_embeddings[t]: K x D
35
+ # word_embeddings[neg_idx]: |V_{neg}| x D
36
+ loss_UWE += self.ETC.compute_loss(topic_embeddings[t], time_neg_WE, temperature=self.temperature, all_neg=True)
37
+ cnt_valid_times += 1
38
+
39
+ if cnt_valid_times > 0:
40
+ loss_UWE *= (self.weight_UWE / cnt_valid_times)
41
+
42
+ return loss_UWE
43
+
44
+ def get_topk_indices(self, beta):
45
+ # topk_indices: T x K x neg_topk
46
+ topk_indices = torch.topk(beta, k=self.neg_topk, dim=-1).indices
47
+ topk_indices = torch.flatten(topk_indices, start_dim=1)
48
+ return topk_indices
backend/models/CFDTM/__init__.py ADDED
File without changes
backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc ADDED
Binary file (4.01 kB)

backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc ADDED
Binary file (1.85 kB)

backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc ADDED
Binary file (1.52 kB)

backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc ADDED
Binary file (1.56 kB)

backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (158 Bytes)
backend/models/DBERTopic_trainer.py ADDED
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+ from bertopic import BERTopic
3
+ from backend.datasets.utils import _utils
4
+ from backend.datasets.utils.logger import Logger
5
+
6
+ logger = Logger("WARNING")
7
+
8
+
9
+ class DBERTopicTrainer:
10
+ def __init__(self,
11
+ dataset,
12
+ num_topics=20,
13
+ num_top_words=15,
14
+ nr_bins=20,
15
+ global_tuning=True,
16
+ evolution_tuning=True,
17
+ datetime_format=None,
18
+ verbose=False):
19
+
20
+ self.dataset = dataset
21
+ self.docs = dataset.raw_documents
22
+ self.num_topics = num_topics
23
+ # self.timestamps = dataset.train_times
24
+ self.vocab = dataset.vocab
25
+ self.num_top_words = num_top_words
26
+ # self.nr_bins = nr_bins
27
+ # self.global_tuning = global_tuning
28
+ # self.evolution_tuning = evolution_tuning
29
+ # self.datetime_format = datetime_format
30
+ self.verbose = verbose
31
+
32
+ if verbose:
33
+ logger.set_level("DEBUG")
34
+ else:
35
+ logger.set_level("WARNING")
36
+
37
+ def train(self, timestamps, datetime_format='%Y'):
38
+ logger.info("Fitting BERTopic...")
39
+ self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
40
+ self.topics, _ = self.model.fit_transform(self.docs)
41
+
42
+ logger.info("Running topics_over_time...")
43
+ self.topics_over_time_df = self.model.topics_over_time(
44
+ docs=self.docs,
45
+ timestamps=timestamps,
46
+ nr_bins=len(set(timestamps)),
47
+ datetime_format=datetime_format
48
+ )
49
+
50
+ self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
51
+ self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
52
+ self.vocab = self.model.vectorizer_model.get_feature_names_out()
53
+ self.V = len(self.vocab)
54
+ self.K = len(self.unique_topics)
55
+ self.T = len(self.unique_timestamps)
56
+
57
+ def get_beta(self):
58
+ logger.info("Generating β matrix...")
59
+
60
+ beta = np.zeros((self.T, self.K, self.V))
61
+ topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}
62
+ timestamp_to_index = {timestamp: idx for idx, timestamp in enumerate(self.unique_timestamps)}
63
+
64
+ # Extract topic representations at each time
65
+ for t_idx, timestamp in enumerate(self.unique_timestamps):
66
+ selection = self.topics_over_time_df[self.topics_over_time_df["Timestamp"] == timestamp]
67
+ for _, row in selection.iterrows():
68
+ topic = row["Topic"]
69
+ words = row["Words"].split(", ")
70
+ if topic not in topic_to_index:
71
+ continue
72
+ k = topic_to_index[topic]
73
+ for word in words:
74
+ if word in self.vocab:
75
+ v = np.where(self.vocab == word)[0][0]
76
+ beta[t_idx, k, v] += 1.0
77
+
78
+ # Normalize each β_tk to be a probability distribution
79
+ beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
80
+ return beta
81
+
82
+ def get_top_words(self, num_top_words=None):
83
+ if num_top_words is None:
84
+ num_top_words = self.num_top_words
85
+ beta = self.get_beta()
86
+ top_words_list = list()
87
+ for time in range(beta.shape[0]):
88
+ top_words = _utils.get_top_words(beta[time], self.vocab, num_top_words, self.verbose)
89
+ top_words_list.append(top_words)
90
+ return top_words_list
91
+
92
+ def get_theta(self):
93
+ # Not applicable for BERTopic; can return topic assignments or soft topic distributions if required
94
+ logger.warning("get_theta is not implemented for BERTopic.")
95
+ return None
96
+
97
+ def export_theta(self):
98
+ logger.warning("export_theta is not implemented for BERTopic.")
99
+ return None, None
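Usage sketch; dataset must expose raw_documents and vocab as the constructor above assumes, and timestamps must align with the documents:

    trainer = DBERTopicTrainer(dataset, num_topics=20)
    trainer.train(timestamps, datetime_format="%Y")
    beta = trainer.get_beta()            # shape (T, K, V)
    top_words = trainer.get_top_words()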
backend/models/DETM.py ADDED
@@ -0,0 +1,259 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class DETM(nn.Module):
+     """
+     The Dynamic Embedded Topic Model (2019).
+
+     Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei
+     """
+     def __init__(self, vocab_size, num_times, train_size, train_time_wordfreq,
+                  num_topics=50, train_WE=True, pretrained_WE=None, en_units=800,
+                  eta_hidden_size=200, rho_size=300, enc_drop=0.0, eta_nlayers=3,
+                  eta_dropout=0.0, delta=0.005, theta_act='relu', device='cpu'):
+         super().__init__()
+
+         ## define hyperparameters
+         self.num_topics = num_topics
+         self.num_times = num_times
+         self.vocab_size = vocab_size
+         self.eta_hidden_size = eta_hidden_size
+         self.rho_size = rho_size
+         self.enc_drop = enc_drop
+         self.eta_nlayers = eta_nlayers
+         self.t_drop = nn.Dropout(enc_drop)
+         self.eta_dropout = eta_dropout
+         self.delta = delta
+         self.train_WE = train_WE
+         self.train_size = train_size
+         self.rnn_inp = train_time_wordfreq
+         self.device = device
+
+         self.theta_act = self.get_activation(theta_act)
+
+         ## define the word embedding matrix \rho
+         if self.train_WE:
+             self.rho = nn.Linear(self.rho_size, self.vocab_size, bias=False)
+         else:
+             # pretrained_WE is a NumPy array of shape (vocab_size, rho_size)
+             rho = nn.Embedding(pretrained_WE.shape[0], pretrained_WE.shape[1])
+             rho.weight.data = torch.from_numpy(pretrained_WE)
+             self.rho = rho.weight.data.clone().float().to(self.device)
+
+         ## define the variational parameters for the topic embeddings over time (alpha) ... alpha is K x T x L
+         self.mu_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
+         self.logsigma_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
+
+         ## define variational distribution for \theta_{1:D} via amortization ... theta is D x K
+         self.q_theta = nn.Sequential(
+             nn.Linear(self.vocab_size + self.num_topics, en_units),
+             self.theta_act,
+             nn.Linear(en_units, en_units),
+             self.theta_act,
+         )
+         self.mu_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
+         self.logsigma_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
+
+         ## define variational distribution for \eta via amortization ... eta is T x K
+         self.q_eta_map = nn.Linear(self.vocab_size, self.eta_hidden_size)
+         self.q_eta = nn.LSTM(self.eta_hidden_size, self.eta_hidden_size, self.eta_nlayers, dropout=self.eta_dropout)
+         self.mu_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
+         self.logsigma_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
+
+         self.decoder_bn = nn.BatchNorm1d(vocab_size)
+         self.decoder_bn.weight.requires_grad = False
+
+     def get_activation(self, act):
+         activations = {
+             'tanh': nn.Tanh(),
+             'relu': nn.ReLU(),
+             'softplus': nn.Softplus(),
+             'rrelu': nn.RReLU(),
+             'leakyrelu': nn.LeakyReLU(),
+             'elu': nn.ELU(),
+             'selu': nn.SELU(),
+             'glu': nn.GLU(),
+         }
+
+         if act in activations:
+             act = activations[act]
+         else:
+             print('Unknown activation; defaulting to tanh...')
+             act = nn.Tanh()
+         return act
+
+     def reparameterize(self, mu, logvar):
+         """Returns a sample from a Gaussian distribution via reparameterization."""
+         if self.training:
+             std = torch.exp(0.5 * logvar)
+             eps = torch.randn_like(std)
+             return eps.mul_(std).add_(mu)
+         else:
+             return mu
+
+     def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
+         """Returns KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) )."""
+         if p_mu is not None and p_logsigma is not None:
+             sigma_q_sq = torch.exp(q_logsigma)
+             sigma_p_sq = torch.exp(p_logsigma)
+             kl = (sigma_q_sq + (q_mu - p_mu) ** 2) / (sigma_p_sq + 1e-6)
+             kl = kl - 1 + p_logsigma - q_logsigma
+             kl = 0.5 * torch.sum(kl, dim=-1)
+         else:
+             # KL against a standard Gaussian prior
+             kl = -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)
+         return kl
+
+     def get_alpha(self):  ## mean field
+         alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(self.device)
+         kl_alpha = []
+
+         alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])
+
+         # the prior at t=0 is a standard Gaussian, hence logsigma_p_0 is zero
+         p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
+         logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
+         kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
+         kl_alpha.append(kl_0)
+         for t in range(1, self.num_times):
+             alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])
+
+             # Markov prior: alpha_t ~ N(alpha_{t-1}, delta * I)
+             p_mu_t = alphas[t - 1]
+             logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(self.device))
+             kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
+             kl_alpha.append(kl_t)
+         kl_alpha = torch.stack(kl_alpha).sum()
+         return alphas, kl_alpha
+
+     def get_eta(self, rnn_inp):  ## structured amortized inference
+         inp = self.q_eta_map(rnn_inp).unsqueeze(1)
+         hidden = self.init_hidden()
+         output, _ = self.q_eta(inp, hidden)
+         output = output.squeeze()
+
+         etas = torch.zeros(self.num_times, self.num_topics).to(self.device)
+         kl_eta = []
+
+         inp_0 = torch.cat([output[0], torch.zeros(self.num_topics).to(self.device)], dim=0)
+         mu_0 = self.mu_q_eta(inp_0)
+         logsigma_0 = self.logsigma_q_eta(inp_0)
+         etas[0] = self.reparameterize(mu_0, logsigma_0)
+
+         p_mu_0 = torch.zeros(self.num_topics).to(self.device)
+         logsigma_p_0 = torch.zeros(self.num_topics).to(self.device)
+         kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
+         kl_eta.append(kl_0)
+
+         for t in range(1, self.num_times):
+             inp_t = torch.cat([output[t], etas[t - 1]], dim=0)
+             mu_t = self.mu_q_eta(inp_t)
+             logsigma_t = self.logsigma_q_eta(inp_t)
+             etas[t] = self.reparameterize(mu_t, logsigma_t)
+
+             p_mu_t = etas[t - 1]
+             logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics).to(self.device))
+             kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
+             kl_eta.append(kl_t)
+         kl_eta = torch.stack(kl_eta).sum()
+
+         return etas, kl_eta
+
+     def get_theta(self, bows, times, eta=None):  ## amortized inference
+         """Returns the topic proportions."""
+         normalized_bows = bows / bows.sum(1, keepdim=True)
+
+         if eta is None and self.training is False:
+             eta, _ = self.get_eta(self.rnn_inp)
+
+         eta_td = eta[times]
+         inp = torch.cat([normalized_bows, eta_td], dim=1)
+         q_theta = self.q_theta(inp)
+         if self.enc_drop > 0:
+             q_theta = self.t_drop(q_theta)
+         mu_theta = self.mu_q_theta(q_theta)
+         logsigma_theta = self.logsigma_q_theta(q_theta)
+         z = self.reparameterize(mu_theta, logsigma_theta)
+         theta = F.softmax(z, dim=-1)
+         kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(self.device))
+
+         if self.training:
+             return theta, kl_theta
+         else:
+             return theta
+
+     @property
+     def word_embeddings(self):
+         return self.rho.weight if self.train_WE else self.rho
+
+     @property
+     def topic_embeddings(self):
+         alpha, _ = self.get_alpha()
+         return alpha
+
+     def get_beta(self, alpha=None):
+         """Returns the topic matrix \\beta of shape T x K x V."""
+         if alpha is None:
+             alpha, _ = self.get_alpha()
+
+         if self.train_WE:
+             logit = self.rho(alpha.view(alpha.size(0) * alpha.size(1), self.rho_size))
+         else:
+             tmp = alpha.view(alpha.size(0) * alpha.size(1), self.rho_size)
+             logit = torch.mm(tmp, self.rho.permute(1, 0))
+         logit = logit.view(alpha.size(0), alpha.size(1), -1)
+
+         beta = F.softmax(logit, dim=-1)
+         return beta
+
+     def get_NLL(self, theta, beta, bows):
+         theta = theta.unsqueeze(1)
+         loglik = torch.bmm(theta, beta).squeeze(1)
+         loglik = torch.log(loglik + 1e-12)
+         nll = -loglik * bows
+         nll = nll.sum(-1)
+         return nll
+
+     def forward(self, bows, times):
+         bsz = bows.size(0)
+         coeff = self.train_size / bsz
+         eta, kl_eta = self.get_eta(self.rnn_inp)
+         theta, kl_theta = self.get_theta(bows, times, eta)
+         kl_theta = kl_theta.sum() * coeff
+
+         alpha, kl_alpha = self.get_alpha()
+         beta = self.get_beta(alpha)
+
+         beta = beta[times]
+         nll = self.get_NLL(theta, beta, bows)
+         nll = nll.sum() * coeff
+
+         loss = nll + kl_eta + kl_theta + kl_alpha
+
+         rst_dict = {
+             'loss': loss,
+             'nll': nll,
+             'kl_eta': kl_eta,
+             'kl_theta': kl_theta,
+             'kl_alpha': kl_alpha
+         }
+
+         return rst_dict
+
+     def init_hidden(self):
+         """Initializes the first hidden state of the LSTM used as the inference network for \\eta."""
+         weight = next(self.parameters())
+         nlayers = self.eta_nlayers
+         nhid = self.eta_hidden_size
+         return (weight.new_zeros(nlayers, 1, nhid), weight.new_zeros(nlayers, 1, nhid))
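
A quick forward-pass smoke test may help when wiring this model in. This is a sketch using the constructor signature above with random data; the shapes follow the comments in the code:

```python
# Smoke test for DETM with random inputs (shapes only; no training).
import torch
from backend.models.DETM import DETM

V, T, D, K = 500, 4, 64, 10
rnn_inp = torch.rand(T, V)  # per-time-slice word frequencies fed to the eta LSTM

model = DETM(vocab_size=V, num_times=T, train_size=D,
             train_time_wordfreq=rnn_inp, num_topics=K, device='cpu')

bows = torch.rand(8, V)                  # a batch of 8 bag-of-words vectors
times = torch.randint(0, T, (8,))        # a time-slice index per document
rst = model(bows, times)
print(rst['loss'].item(), rst['nll'].item(), rst['kl_alpha'].item())
```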
backend/models/DTM_trainer.py ADDED
@@ -0,0 +1,148 @@
+ import gensim
+ import numpy as np
+ from gensim.models import ldaseqmodel
+ from tqdm import tqdm
+ import datetime
+ from multiprocessing.pool import Pool
+ from backend.datasets.utils import _utils
+ from backend.datasets.utils.logger import Logger
+
+
+ logger = Logger("WARNING")
+
+
+ def work(arguments):
+     # Infer theta for a chunk of documents in a worker process.
+     model, docs = arguments
+     theta_list = list()
+     for doc in tqdm(docs):
+         theta_list.append(model[doc])
+     return theta_list
+
+
+ class DTMTrainer:
+     def __init__(self,
+                  dataset,
+                  num_topics=50,
+                  num_top_words=15,
+                  alphas=0.01,
+                  chain_variance=0.005,
+                  passes=10,
+                  lda_inference_max_iter=25,
+                  em_min_iter=6,
+                  em_max_iter=20,
+                  verbose=False
+                  ):
+
+         self.dataset = dataset
+         self.vocab_size = dataset.vocab_size
+         self.num_topics = num_topics
+         self.num_top_words = num_top_words
+         self.alphas = alphas
+         self.chain_variance = chain_variance
+         self.passes = passes
+         self.lda_inference_max_iter = lda_inference_max_iter
+         self.em_min_iter = em_min_iter
+         self.em_max_iter = em_max_iter
+
+         self.verbose = verbose
+         if verbose:
+             logger.set_level("DEBUG")
+         else:
+             logger.set_level("WARNING")
+
+     def train(self):
+         id2word = dict(zip(range(self.vocab_size), self.dataset.vocab))
+         train_bow = self.dataset.train_bow
+         train_times = self.dataset.train_times.astype('int32')
+
+         # order documents by time slice; LdaSeqModel expects a time-sorted corpus
+         self.doc_order_idx = np.argsort(train_times)
+         train_bow = train_bow[self.doc_order_idx]
+         time_slices = np.bincount(train_times)
+
+         corpus = gensim.matutils.Dense2Corpus(train_bow, documents_columns=False)
+
+         self.model = ldaseqmodel.LdaSeqModel(
+             corpus=corpus,
+             id2word=id2word,
+             time_slice=time_slices,
+             num_topics=self.num_topics,
+             alphas=self.alphas,
+             chain_variance=self.chain_variance,
+             em_min_iter=self.em_min_iter,
+             em_max_iter=self.em_max_iter,
+             lda_inference_max_iter=self.lda_inference_max_iter,
+             passes=self.passes
+         )
+
+     def test(self, bow):
+         corpus = gensim.matutils.Dense2Corpus(bow, documents_columns=False)
+
+         # split the corpus into roughly equal chunks, one per worker
+         num_workers = 20
+         split_idx_list = np.array_split(np.arange(len(bow)), num_workers)
+         worker_size_list = [len(x) for x in split_idx_list]
+
+         worker_id = 0
+         docs_list = [list() for i in range(num_workers)]
+         for i, doc in enumerate(corpus):
+             docs_list[worker_id].append(doc)
+             if len(docs_list[worker_id]) >= worker_size_list[worker_id]:
+                 worker_id += 1
+
+         args_list = list()
+         for docs in docs_list:
+             args_list.append([self.model, docs])
+
+         starttime = datetime.datetime.now()
+
+         pool = Pool(processes=num_workers)
+         results = pool.map(work, args_list)
+
+         pool.close()
+         pool.join()
+
+         theta_list = list()
+         for rst in results:
+             theta_list.extend(rst)
+
+         endtime = datetime.datetime.now()
+
+         logger.info("DTM test time: {}s".format((endtime - starttime).seconds))
+
+         return np.asarray(theta_list)
+
+     def get_theta(self):
+         theta = self.model.gammas / self.model.gammas.sum(axis=1)[:, np.newaxis]
+         # NOTE: gammas are in time-sorted order; MUST transform back to the original document order.
+         return theta[np.argsort(self.doc_order_idx)]
+
+     def get_beta(self):
+         beta = list()
+         # topic_chains holds K chains, each with log-probabilities of shape V x T
+         for item in self.model.topic_chains:
+             # V x T
+             beta.append(item.e_log_prob)
+
+         # T x K x V
+         beta = np.transpose(np.asarray(beta), (2, 0, 1))
+         # exponentiate the log-probabilities and renormalize over the vocabulary
+         beta = np.exp(beta)
+         beta = beta / beta.sum(-1, keepdims=True)
+         return beta
+
+     def get_top_words(self, num_top_words=None):
+         if num_top_words is None:
+             num_top_words = self.num_top_words
+         beta = self.get_beta()
+         top_words_list = list()
+         for time in range(beta.shape[0]):
+             top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
+             top_words_list.append(top_words)
+         return top_words_list
+
+     def export_theta(self):
+         train_theta = self.get_theta()
+         test_theta = self.test(self.dataset.test_bow)
+         return train_theta, test_theta
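
The reordering step in `train()` is easy to get wrong, so a small worked example of how gensim's `time_slice` argument is derived may be useful (illustrative data only):

```python
# How train() prepares LdaSeqModel inputs: documents must be grouped by time slice,
# and time_slice gives the number of documents in each slice, in order.
import numpy as np

train_times = np.array([2, 0, 1, 0, 2, 1], dtype='int32')  # slice index per document
doc_order_idx = np.argsort(train_times)                    # -> docs regrouped as slices 0,0,1,1,2,2
time_slices = np.bincount(train_times)                     # -> array([2, 2, 2]): docs per slice
# gensim then treats the first time_slices[0] docs as slice 0, the next time_slices[1] as slice 1, etc.
```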
backend/models/dynamic_trainer.py ADDED
@@ -0,0 +1,177 @@
+ import numpy as np
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import torch
+ from torch.optim.lr_scheduler import StepLR
+ from backend.datasets.utils import _utils
+ from backend.datasets.utils.logger import Logger
+
+ logger = Logger("WARNING")
+
+
+ class DynamicTrainer:
+     def __init__(self,
+                  model,
+                  dataset,
+                  num_top_words=15,
+                  epochs=200,
+                  learning_rate=0.002,
+                  batch_size=200,
+                  lr_scheduler=None,
+                  lr_step_size=125,
+                  log_interval=5,
+                  verbose=False
+                  ):
+
+         self.model = model
+         self.dataset = dataset
+         self.num_top_words = num_top_words
+         self.epochs = epochs
+         self.learning_rate = learning_rate
+         self.batch_size = batch_size
+         self.lr_scheduler = lr_scheduler
+         self.lr_step_size = lr_step_size
+         self.log_interval = log_interval
+
+         self.verbose = verbose
+         if verbose:
+             logger.set_level("DEBUG")
+         else:
+             logger.set_level("WARNING")
+
+     def make_optimizer(self):
+         args_dict = {
+             'params': self.model.parameters(),
+             'lr': self.learning_rate,
+         }
+
+         optimizer = torch.optim.Adam(**args_dict)
+         return optimizer
+
+     def make_lr_scheduler(self, optimizer):
+         lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5, verbose=False)
+         return lr_scheduler
+
+     def train(self):
+         optimizer = self.make_optimizer()
+
+         if self.lr_scheduler:
+             logger.info("using lr_scheduler")
+             lr_scheduler = self.make_lr_scheduler(optimizer)
+
+         data_size = len(self.dataset.train_dataloader.dataset)
+
+         for epoch in tqdm(range(1, self.epochs + 1)):
+             self.model.train()
+             loss_rst_dict = defaultdict(float)
+
+             for batch_data in self.dataset.train_dataloader:
+                 rst_dict = self.model(batch_data['bow'], batch_data['times'])
+                 batch_loss = rst_dict['loss']
+
+                 optimizer.zero_grad()
+                 batch_loss.backward()
+                 optimizer.step()
+
+                 # weight each loss term by the batch size, not len(batch_data),
+                 # which would count the dict keys; .item() detaches the graph
+                 batch_size = len(batch_data['bow'])
+                 for key in rst_dict:
+                     loss_rst_dict[key] += rst_dict[key].item() * batch_size
+
+             if self.lr_scheduler:
+                 lr_scheduler.step()
+
+             if epoch % self.log_interval == 0:
+                 output_log = f'Epoch: {epoch:03d}'
+                 for key in loss_rst_dict:
+                     output_log += f' {key}: {loss_rst_dict[key] / data_size :.3f}'
+
+                 logger.info(output_log)
+
+         top_words = self.get_top_words()
+         train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
+
+         return top_words, train_theta
+
+     def test(self, bow, times):
+         data_size = bow.shape[0]
+         theta = list()
+         all_idx = torch.split(torch.arange(data_size), self.batch_size)
+
+         with torch.no_grad():
+             self.model.eval()
+             for idx in all_idx:
+                 batch_theta = self.model.get_theta(bow[idx], times[idx])
+                 theta.extend(batch_theta.cpu().tolist())
+
+         theta = np.asarray(theta)
+         return theta
+
+     def get_beta(self):
+         self.model.eval()
+         beta = self.model.get_beta().detach().cpu().numpy()
+         return beta
+
+     def get_top_words(self, num_top_words=None):
+         if num_top_words is None:
+             num_top_words = self.num_top_words
+
+         beta = self.get_beta()
+         top_words_list = list()
+         for time in range(beta.shape[0]):
+             if self.verbose:
+                 print(f"======= Time: {time} =======")
+             top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
+             top_words_list.append(top_words)
+         return top_words_list
+
+     def export_theta(self):
+         train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
+         test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)
+
+         return train_theta, test_theta
+
+     def get_top_words_at_time(self, topic_id, time, top_n, beta=None):
+         # beta can be passed in to avoid recomputing it on every call
+         if beta is None:
+             beta = self.get_beta()  # shape: [T, K, V]
+         topic_beta = beta[time, topic_id, :]
+         top_indices = topic_beta.argsort()[-top_n:][::-1]
+         return [self.dataset.vocab[i] for i in top_indices]
+
+     def get_topic_words_over_time(self, topic_id, top_n):
+         """
+         Returns top_n words for the given topic_id over all time steps.
+         Output: List[List[str]], each inner list is the top_n words at a time step.
+         """
+         beta = self.get_beta()  # shape: [T, K, V]
+         T = beta.shape[0]
+         return [
+             self.get_top_words_at_time(topic_id=topic_id, time=t, top_n=top_n, beta=beta)
+             for t in range(T)
+         ]
+
+     def get_all_topics_at_time(self, time, top_n):
+         """
+         Returns top_n words for each topic at the given time step.
+         Output: List[List[str]], each inner list is the top_n words for a topic.
+         """
+         beta = self.get_beta()  # shape: [T, K, V]
+         K = beta.shape[1]
+         return [
+             self.get_top_words_at_time(topic_id=k, time=time, top_n=top_n, beta=beta)
+             for k in range(K)
+         ]
+
+     def get_all_topics_over_time(self, top_n=10):
+         """
+         Returns the top_n words for all topics over all time steps.
+         Output shape: List[List[List[str]]] = T x K x top_n
+         """
+         beta = self.get_beta()  # shape: [T, K, V]
+         T, K, _ = beta.shape
+         return [
+             [
+                 self.get_top_words_at_time(topic_id=k, time=t, top_n=top_n, beta=beta)
+                 for k in range(K)
+             ]
+             for t in range(T)
+         ]
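
An end-to-end sketch of how the model and trainer fit together; the `dataset` attributes beyond those used in the code above (`num_times`, `train_time_wordfreq`) are assumptions, and building the dataset object itself is omitted:

```python
# Hedged sketch: wiring DETM into DynamicTrainer; dataset construction not shown.
from backend.models.DETM import DETM
from backend.models.dynamic_trainer import DynamicTrainer

model = DETM(vocab_size=dataset.vocab_size,
             num_times=dataset.num_times,                     # assumed attribute
             train_size=len(dataset.train_bow),
             train_time_wordfreq=dataset.train_time_wordfreq,  # assumed attribute
             num_topics=20)

trainer = DynamicTrainer(model, dataset, epochs=100, verbose=True)
top_words, train_theta = trainer.train()
beta = trainer.get_beta()                                # (T, K, V)
words_t0 = trainer.get_all_topics_at_time(time=0, top_n=10)
```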
data/ACL_Anthology/CFDTM/beta.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34984bfb432a10733161a9dfed834a9ef4f366a28a6cb2ecd6e8351997f1599a
+ size 16645248
data/ACL_Anthology/DETM/beta.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c6eefa9b6aaea4c694736d09ad9e517446f09929c01889e26633300e5eff166
+ size 41612928
data/ACL_Anthology/DTM/beta.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14c296a2e3fb49f9d0b66262907d64f7d181408768e43138d57c262ea6a11318
+ size 33290368
data/ACL_Anthology/DTM/topic_label_cache.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea9f3c508ede82967cdf02050d7383d58dd9d269a7f661ae1462a95cbac3331e
+ size 2089
data/ACL_Anthology/docs.jsonl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a004dd095b9a4f29fdccb5144d50d3dacc7985af443a8de434005b7b8401f9b7
+ size 67395059
data/ACL_Anthology/inverted_index.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60e7ee888abb2fd025b11415a7ead6780d41c5f890cc25ba453615906f10b8d7
+ size 30865281
data/ACL_Anthology/processed/lemma_to_forms.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00ea8855f9ced2ca3d785ce5926ced29b35e0779cd6b3166edfd5c5a5c1beccb
+ size 4370995
data/ACL_Anthology/processed/length_stats.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5cc985e5a1ce565ca4179d343ade1526daab463520f6317122953da83d368306
+ size 133
data/ACL_Anthology/processed/time2id.txt ADDED
@@ -0,0 +1,18 @@
+ {
+     "2010": 0,
+     "2011": 1,
+     "2012": 2,
+     "2013": 3,
+     "2014": 4,
+     "2015": 5,
+     "2016": 6,
+     "2017": 7,
+     "2018": 8,
+     "2019": 9,
+     "2020": 10,
+     "2021": 11,
+     "2022": 12,
+     "2023": 13,
+     "2024": 14,
+     "2025": 15
+ }
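
`time2id.txt` is the year-to-slice mapping the dynamic models use as `times` indices; since the file is plain JSON, loading it is a one-liner (sketch):

```python
# Load the year -> time-slice-index mapping shipped with the ACL Anthology data.
import json

with open("data/ACL_Anthology/processed/time2id.txt") as f:
    time2id = json.load(f)
print(time2id["2010"])  # -> 0, the first time slice
```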
data/ACL_Anthology/processed/vocab.txt ADDED
The diff for this file is too large to render. See raw diff