Merge master into main, resolved conflicts and updated LFS tracking
Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +5 -0
- .huggingface.yaml +3 -0
- LICENSE +21 -0
- app/ui_updated.py +450 -0
- assets/Logo_light.png +3 -0
- backend/__init__.py +81 -0
- backend/datasets/_preprocess.py +447 -0
- backend/datasets/data/download.py +32 -0
- backend/datasets/data/file_utils.py +39 -0
- backend/datasets/dynamic_dataset.py +90 -0
- backend/datasets/preprocess.py +362 -0
- backend/datasets/utils/_utils.py +37 -0
- backend/datasets/utils/logger.py +29 -0
- backend/evaluation/CoherenceModel_ttc.py +862 -0
- backend/evaluation/eval.py +179 -0
- backend/inference/doc_retriever.py +219 -0
- backend/inference/indexing_utils.py +146 -0
- backend/inference/peak_detector.py +18 -0
- backend/inference/process_beta.py +33 -0
- backend/inference/word_selector.py +102 -0
- backend/llm/custom_gemini.py +28 -0
- backend/llm/custom_mistral.py +27 -0
- backend/llm/llm_router.py +73 -0
- backend/llm_utils/label_generator.py +72 -0
- backend/llm_utils/summarizer.py +192 -0
- backend/llm_utils/token_utils.py +167 -0
- backend/models/CFDTM/CFDTM.py +127 -0
- backend/models/CFDTM/ETC.py +62 -0
- backend/models/CFDTM/Encoder.py +40 -0
- backend/models/CFDTM/UWE.py +48 -0
- backend/models/CFDTM/__init__.py +0 -0
- backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc +0 -0
- backend/models/DBERTopic_trainer.py +99 -0
- backend/models/DETM.py +259 -0
- backend/models/DTM_trainer.py +148 -0
- backend/models/dynamic_trainer.py +177 -0
- data/ACL_Anthology/CFDTM/beta.npy +3 -0
- data/ACL_Anthology/DETM/beta.npy +3 -0
- data/ACL_Anthology/DTM/beta.npy +3 -0
- data/ACL_Anthology/DTM/topic_label_cache.json +3 -0
- data/ACL_Anthology/docs.jsonl +3 -0
- data/ACL_Anthology/inverted_index.json +3 -0
- data/ACL_Anthology/processed/lemma_to_forms.json +3 -0
- data/ACL_Anthology/processed/length_stats.json +3 -0
- data/ACL_Anthology/processed/time2id.txt +18 -0
- data/ACL_Anthology/processed/vocab.txt +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/**/*.npy filter=lfs diff=lfs merge=lfs -text
+data/**/*.jsonl filter=lfs diff=lfs merge=lfs -text
+data/**/*.json filter=lfs diff=lfs merge=lfs -text
+assets/*.png filter=lfs diff=lfs merge=lfs -text
+data/**/*.npz filter=lfs diff=lfs merge=lfs -text
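The five new patterns route the repository's data artifacts (NumPy arrays, JSONL/JSON data files, the logo PNG, and .npz archives) through Git LFS instead of storing them as ordinary blobs. As a quick sanity check after cloning, you can ask Git which filter applies to a path covered by the new rules; the snippet below is a minimal sketch using a path that appears in this commit's file list.

import subprocess

# Ask git which clean/smudge filter applies to the path.
# Expected output for an LFS-tracked path: "<path>: filter: lfs"
result = subprocess.run(
    ["git", "check-attr", "filter", "--", "data/ACL_Anthology/DTM/beta.npy"],
    capture_output=True, text=True, check=True
)
print(result.stdout.strip())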
.huggingface.yaml
ADDED
@@ -0,0 +1,3 @@
# .huggingface.yaml
sdk: streamlit # or gradio
app_file: ./app/ui.py
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Suman Adhya

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app/ui_updated.py
ADDED
@@ -0,0 +1,450 @@
import streamlit as st
import plotly.graph_objects as go
import plotly.colors as pc
import sys
import os
import base64
import streamlit.components.v1 as components
import html

# Absolute path to the repo root (assuming `ui.py` is in /app)
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(REPO_ROOT)
ASSETS_DIR = os.path.join(REPO_ROOT, 'assets')
DATA_DIR = os.path.join(REPO_ROOT, 'data')

# Import functions from the backend
from backend.inference.process_beta import (
    load_beta_matrix,
    get_top_words_over_time,
    load_time_labels
)
from backend.inference.word_selector import get_interesting_words, get_word_trend
from backend.inference.indexing_utils import load_index
from backend.inference.doc_retriever import (
    load_length_stats,
    get_yearly_counts_for_word,
    deduplicate_docs,
    get_all_documents_for_word_year,
    highlight_words,
    extract_snippet
)
from backend.llm_utils.summarizer import summarize_multiword_docs, ask_multiturn_followup
from backend.llm_utils.label_generator import get_topic_labels
from backend.llm.llm_router import get_llm, list_supported_models
from backend.llm_utils.token_utils import estimate_k_max_from_word_stats

def get_base64_image(image_path):
    with open(image_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode()

# --- Page Configuration ---
st.set_page_config(
    page_title="DTECT",
    page_icon="🔍",
    layout="wide"
)

# Sidebar branding and repo link
st.sidebar.markdown(
    """
    <div style="text-align: center;">
        <a href="https://github.com/dinb-ai/DTECT" target="_blank">
            <img src="data:image/png;base64,{}" width="180" style="margin-bottom: 18px;">
        </a>
        <hr style="margin-bottom: 0;">
    </div>
    """.format(get_base64_image(os.path.join(ASSETS_DIR, 'Logo_light.png'))),
    unsafe_allow_html=True
)

# 1. Sidebar: Model and Dataset Selection
st.sidebar.title("Configuration")

AVAILABLE_MODELS = ["DTM", "DETM", "CFDTM"]
ENV_VAR_MAP = {
    "OpenAI": "OPENAI_API_KEY",
    "Anthropic": "ANTHROPIC_API_KEY",
    "Gemini": "GEMINI_API_KEY",
    "Mistral": "MISTRAL_API_KEY"
}

def list_datasets(data_dir):
    return sorted([
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    ])

with st.sidebar.expander("Select Dataset & Topic Model", expanded=True):
    datasets = list_datasets(DATA_DIR)
    selected_dataset = st.selectbox("Dataset", datasets, help="Choose an available dataset.")
    selected_model = st.selectbox("Model", AVAILABLE_MODELS, help="Select topic model architecture.")

# Resolve paths
dataset_path = os.path.join(DATA_DIR, selected_dataset)
model_path = os.path.join(dataset_path, selected_model)
docs_path = os.path.join(dataset_path, "docs.jsonl")
vocab_path = os.path.join(dataset_path, "processed/vocab.txt")
time2id_path = os.path.join(dataset_path, "processed/time2id.txt")
index_path = os.path.join(dataset_path, "inverted_index.json")
beta_path = os.path.join(model_path, "beta.npy")
label_cache_path = os.path.join(model_path, "topic_label_cache.json")
length_stats_path = os.path.join(dataset_path, "processed/length_stats.json")
lemma_map_path = os.path.join(dataset_path, "processed/lemma_to_forms.json")

with st.sidebar.expander("LLM Settings", expanded=True):
    provider = st.selectbox("LLM Provider", options=list(ENV_VAR_MAP.keys()), help="Choose the LLM backend.")
    available_models = list_supported_models(provider)
    model = st.selectbox("LLM Model", options=available_models)
    env_var = ENV_VAR_MAP[provider]
    api_key = os.getenv(env_var)

    if "llm_configured" not in st.session_state:
        st.session_state.llm_configured = False

    if api_key:
        st.session_state.llm_configured = True
    else:
        st.session_state.llm_configured = False
        with st.form(key="api_key_form"):
            entered_key = st.text_input(f"Enter your {provider} API Key", type="password")
            submitted = st.form_submit_button("Submit and Confirm")
            if submitted:
                if entered_key:
                    os.environ[env_var] = entered_key
                    api_key = entered_key
                    st.session_state.llm_configured = True
                    st.rerun()
                else:
                    st.warning("Please enter a key.")

if not st.session_state.llm_configured:
    st.warning("Please configure your LLM settings in the sidebar.")
    st.stop()

if api_key and not st.session_state.llm_configured:
    st.session_state.llm_configured = True

if not api_key:
    st.session_state.llm_configured = False

if not st.session_state.llm_configured:
    st.warning("Please configure your LLM settings in the sidebar.")
    st.stop()

# Initialize LLM with the provided key
llm = get_llm(provider=provider, model=model, api_key=api_key)

# 3. Load Data
@st.cache_resource
def load_resources(beta_path, vocab_path, docs_path, index_path, time2id_path, length_stats_path, lemma_map_path):
    beta, vocab = load_beta_matrix(beta_path, vocab_path)
    index, docs, lemma_to_forms = load_index(docs_file_path=docs_path, vocab=vocab, index_path=index_path, lemma_map_path=lemma_map_path)
    time_labels = load_time_labels(time2id_path)
    length_stats = load_length_stats(length_stats_path)
    return beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats

# --- Main Title and Paper-aligned Intro ---
st.markdown("""# 🔍 DTECT: Dynamic Topic Explorer & Context Tracker""")

# --- Load resources ---
try:
    beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats = load_resources(
        beta_path,
        vocab_path,
        docs_path,
        index_path,
        time2id_path,
        length_stats_path,
        lemma_map_path
    )
except FileNotFoundError as e:
    st.error(f"Missing required file: {e}")
    st.stop()
except Exception as e:
    st.error(f"Failed to load data: {str(e)}")
    st.stop()

timestamps = list(range(len(time_labels)))
num_topics = beta.shape[1]
# Estimate max_k based on document length stats and selected LLM
suggested_max_k = estimate_k_max_from_word_stats(length_stats.get("avg_len"), model_name=model, provider=provider)


# ==============================================================================
# 1. 🏷 TOPIC LABELING
# ==============================================================================
st.markdown("## 1️⃣ 🏷️ Topic Labeling")
st.info("Topics are automatically labeled using LLMs by analyzing their temporal word distributions.")

topic_labels = get_topic_labels(beta, vocab, time_labels, llm, label_cache_path)
topic_options = list(topic_labels.values())
selected_topic_label = st.selectbox("Select a Topic", topic_options, help="LLM-generated topic label")
label_to_topic = {v: k for k, v in topic_labels.items()}
selected_topic = label_to_topic[selected_topic_label]

# ==============================================================================
# 2. 💡 INFORMATIVE WORD DETECTION & 📊 TREND VISUALIZATION
# ==============================================================================
st.markdown("---")
st.markdown("## 2️⃣ 💡 Informative Word Detection & 📊 Trend Visualization")
st.info("Explore top/interesting words for each topic, and visualize their trends over time.")

top_n_words = st.slider("Number of Top Words per Topic", min_value=5, max_value=500, value=10)
top_words = get_top_words_over_time(
    beta=beta,
    vocab=vocab,
    topic_id=selected_topic,
    top_n=top_n_words
)

st.write(f"### Top {top_n_words} Words for Topic '{selected_topic_label}' (Ranked):")
scrollable_top_words = "<div style='max-height: 200px; overflow-y: auto; padding: 0 10px;'>"
words_per_col = (top_n_words + 3) // 4
columns = [top_words[i:i+words_per_col] for i in range(0, len(top_words), words_per_col)]
scrollable_top_words += "<div style='display: flex; gap: 20px;'>"
word_rank = 1
for col in columns:
    scrollable_top_words += "<div style='flex: 1;'>"
    for word in col:
        scrollable_top_words += f"<div style='margin-bottom: 4px;'>{word_rank}. {word}</div>"
        word_rank += 1
    scrollable_top_words += "</div>"
scrollable_top_words += "</div></div>"
st.markdown(scrollable_top_words, unsafe_allow_html=True)

st.markdown("<div style='margin-top: 18px;'></div>", unsafe_allow_html=True)

if st.button("💡 Suggest Informative Words", key="suggest_topic_words"):
    top_words = get_top_words_over_time(
        beta=beta,
        vocab=vocab,
        topic_id=selected_topic,
        top_n=top_n_words
    )
    interesting_words = get_interesting_words(beta, vocab, topic_id=selected_topic, restrict_to=top_words)
    st.session_state.interesting_words = interesting_words
    st.session_state.selected_words = interesting_words[:15]  # pre-fill multiselect
    styled_words = " ".join([
        f"<span style='background-color:#e0f7fa; color:#004d40; font-weight:500; padding:4px 8px; margin:4px; border-radius:8px; display:inline-block;'>{w}</span>"
        for w in interesting_words
    ])
    st.markdown(
        f"**Top Informative Words from Topic '{selected_topic_label}':**<br>{styled_words}",
        unsafe_allow_html=True
    )

st.markdown("#### 📈 Plot Word Trends Over Time")
all_word_options = vocab
interesting_words = st.session_state.get("interesting_words", [])

if "selected_words" not in st.session_state:
    st.session_state.selected_words = interesting_words[:15]  # initial default

selected_words = st.multiselect(
    "Select words to visualize trends",
    options=all_word_options,
    default=st.session_state.selected_words,
    key="selected_words"
)
if selected_words:
    fig = go.Figure()
    color_cycle = pc.qualitative.Plotly
    for i, word in enumerate(selected_words):
        trend = get_word_trend(beta, vocab, word, topic_id=selected_topic)
        color = color_cycle[i % len(color_cycle)]
        fig.add_trace(go.Scatter(
            x=time_labels,
            y=trend,
            name=word,
            line=dict(color=color),
            legendgroup=word,
            showlegend=True
        ))
    fig.update_layout(title="", xaxis_title="Year", yaxis_title="Importance")
    st.plotly_chart(fig, use_container_width=True)

# ==============================================================================
# 3. 🔍 DOCUMENT RETRIEVAL & 📃 SUMMARIZATION
# ==============================================================================
st.markdown("---")
st.markdown("## 3️⃣ 🔍 Document Retrieval & 📃 Summarization")
st.info("Retrieve and summarize documents matching selected words and years.")

if selected_words:
    st.markdown("#### 📊 Document Frequency Over Time")
    selected_words_for_counts = st.multiselect(
        "Select word(s) to show document frequencies over time",
        options=selected_words,
        default=selected_words[:3],
        key="word_counts_multiselect"
    )

    if selected_words_for_counts:
        color_cycle = pc.qualitative.Set2
        bar_fig = go.Figure()
        for i, word in enumerate(selected_words_for_counts):
            doc_years, doc_counts = get_yearly_counts_for_word(index=index, word=word)
            bar_fig.add_trace(go.Bar(
                x=doc_years,
                y=doc_counts,
                name=word,
                marker_color=color_cycle[i % len(color_cycle)],
                opacity=0.85
            ))
        bar_fig.update_layout(
            barmode="group",
            title="Document Frequency Over Time",
            xaxis_title="Year",
            yaxis_title="Document Count",
            xaxis=dict(
                tickmode='linear',
                dtick=1,
                tickformat='d'
            ),
            bargap=0.2
        )
        st.plotly_chart(bar_fig, use_container_width=True)

        st.markdown("#### 📄 Inspect Documents for Word-Year Pairs")
        # selected_year = st.slider("Select year", min_value=int(time_labels[0]), max_value=int(time_labels[-1]), key="inspect_year_slider")
        selected_year = st.selectbox(
            "Select year",
            options=time_labels,  # Use the list of available time labels (years)
            index=0,              # Default to the first year in the list
            key="inspect_year_selectbox"
        )
        collected_docs_raw = []
        for word in selected_words_for_counts:
            docs_for_word_year = get_all_documents_for_word_year(
                index=index,
                docs_file_path=docs_path,
                word=word,
                year=selected_year
            )
            for doc in docs_for_word_year:
                doc["__word__"] = word
            collected_docs_raw.extend(docs_for_word_year)

        if collected_docs_raw:
            st.session_state.collected_deduplicated_docs = deduplicate_docs(collected_docs_raw)
            st.write(f"Found {len(collected_docs_raw)} matching documents, {len(st.session_state.collected_deduplicated_docs)} after deduplication.")

            html_blocks = ""
            for doc in st.session_state.collected_deduplicated_docs:
                word = doc["__word__"]
                full_text = html.escape(doc["text"])
                snippet_text = extract_snippet(doc["text"], word)
                highlighted_snippet = highlight_words(
                    snippet_text,
                    query_words=selected_words_for_counts,
                    lemma_to_forms=lemma_to_forms
                )
                html_blocks += f"""
                <div style="margin-bottom: 14px; padding: 10px; background-color: #fffbe6; border: 1px solid #f0e6cc; border-radius: 6px;">
                    <div style="color: #333;"><strong>Match:</strong> {word} | <strong>Doc ID:</strong> {doc['id']} | <strong>Timestamp:</strong> {doc['timestamp']}</div>
                    <div style="margin-top: 4px; color: #444;"><em>Snippet:</em> {highlighted_snippet}</div>
                    <details style="margin-top: 4px;">
                        <summary style="cursor: pointer; color: #007acc;">Show full document</summary>
                        <pre style="white-space: pre-wrap; color: #111; background-color: #fffef5; padding: 8px; border: 1px solid #f0e6cc; border-radius: 4px;">{full_text}</pre>
                    </details>
                </div>
                """
            min_height = 120
            max_height = 700
            per_doc_height = 130
            dynamic_height = min_height + per_doc_height * max(len(st.session_state.collected_deduplicated_docs) - 1, 0)
            container_height = min(dynamic_height, max_height)
            scrollable_html = f"""
            <div style="overflow-y: auto; padding: 10px;
                        border: 1px solid #f0e6cc; border-radius: 6px;
                        background-color: #fffbe6; color: #222;
                        margin-bottom: 0;">
                {html_blocks}
            </div>
            """
            components.html(scrollable_html, height=container_height, scrolling=True)
        else:
            st.warning("No documents found for the selected words and year.")

# ==============================================================================
# 4. 💬 CHAT ASSISTANT (Summary & Follow-up)
# ==============================================================================
st.markdown("---")
st.markdown("## 4️⃣ 💬 Chat Assistant")
st.info("Generate summaries from the inspected documents and ask follow-up questions.")

if "summary" not in st.session_state:
    st.session_state.summary = None
if "context_for_followup" not in st.session_state:
    st.session_state.context_for_followup = ""
if "followup_history" not in st.session_state:
    st.session_state.followup_history = []

# MMR K selection
st.markdown(f"**Max documents for summarization (k):**")
st.markdown(f"The suggested maximum number of documents for summarization (k) based on the average document length and the selected LLM is **{suggested_max_k}**.")
mmr_k = st.slider(
    "Select the maximum number of documents (k) for MMR (Maximum Marginal Relevance) selection for summarization.",
    min_value=1,
    max_value=20,                    # Set a reasonable max for k, can be adjusted
    value=min(suggested_max_k, 20),  # Use suggested_max_k as default, capped at 20
    help="This value determines how many relevant and diverse documents will be selected for summarization."
)

if st.button("📃 Summarize These Documents"):
    if st.session_state.get("collected_deduplicated_docs"):
        st.session_state.summary = None
        st.session_state.context_for_followup = ""
        st.session_state.followup_history = []
        with st.spinner("Selecting and summarizing documents..."):
            summary, mmr_docs = summarize_multiword_docs(
                selected_words_for_counts,
                selected_year,
                st.session_state.collected_deduplicated_docs,
                llm,
                k=mmr_k
            )
            st.session_state.summary = summary
            st.session_state.context_for_followup = "\n".join(
                f"Document {i+1}:\n{doc.page_content.strip()}" for i, doc in enumerate(mmr_docs)
            )
            st.session_state.followup_history.append(
                {"role": "user", "content": f"Please summarize the context of the words '{', '.join(selected_words_for_counts)}' in {selected_year} based on the provided documents."}
            )
            st.session_state.followup_history.append(
                {"role": "assistant", "content": st.session_state.summary}
            )
            st.success(f"✅ Summary generated from {len(mmr_docs)} MMR-selected documents.")
    else:
        st.warning("⚠️ No documents collected to summarize. Please inspect some documents first.")

if st.session_state.summary:
    st.markdown(f"**Summary for words `{', '.join(selected_words_for_counts)}` in `{selected_year}`:**")
    st.write(st.session_state.summary)

    if st.checkbox("💬 Ask follow-up questions about this summary", key="enable_followup"):
        with st.expander("View the documents used for this conversation"):
            st.text_area("Context Documents", st.session_state.context_for_followup, height=200)
        st.info("Ask a question based on the summary and the documents above.")
        for msg in st.session_state.followup_history[2:]:
            with st.chat_message(msg["role"], avatar="🧑" if msg["role"] == "user" else "🤖"):
                st.markdown(msg["content"])
        if user_query := st.chat_input("Ask a follow-up question..."):
            with st.chat_message("user", avatar="🧑"):
                st.markdown(user_query)
            st.session_state.followup_history.append({"role": "user", "content": user_query})
            with st.spinner("Thinking..."):
                followup_response = ask_multiturn_followup(
                    history=st.session_state.followup_history,
                    question=user_query,
                    llm=llm,
                    context_texts=st.session_state.context_for_followup
                )
            st.session_state.followup_history.append({"role": "assistant", "content": followup_response})
            if followup_response.startswith("[Error"):
                st.error(followup_response)
            else:
                with st.chat_message("assistant", avatar="🤖"):
                    st.markdown(followup_response)
            st.rerun()
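For reference, the path resolution at the top of ui_updated.py implies a fixed per-dataset layout under data/<dataset>/, with one subdirectory per topic model holding beta.npy. The helper below is an illustrative sketch of a startup check against that layout (it is not part of the commit; the cache file topic_label_cache.json is omitted since the UI can regenerate it).

import os

# Files the UI expects for a dataset/model pair, per the path resolution in ui_updated.py.
def missing_dataset_files(data_dir, dataset, model):
    dataset_path = os.path.join(data_dir, dataset)
    expected = [
        os.path.join(dataset_path, "docs.jsonl"),
        os.path.join(dataset_path, "inverted_index.json"),
        os.path.join(dataset_path, "processed", "vocab.txt"),
        os.path.join(dataset_path, "processed", "time2id.txt"),
        os.path.join(dataset_path, "processed", "length_stats.json"),
        os.path.join(dataset_path, "processed", "lemma_to_forms.json"),
        os.path.join(dataset_path, model, "beta.npy"),
    ]
    return [p for p in expected if not os.path.exists(p)]

print(missing_dataset_files("data", "ACL_Anthology", "DTM"))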
assets/Logo_light.png
ADDED
Binary image stored with Git LFS.
backend/__init__.py
ADDED
@@ -0,0 +1,81 @@
# === Inference components ===
from .inference.process_beta import (
    load_beta_matrix,
    get_top_words_at_time,
    get_top_words_over_time,
    load_time_labels
)

from .inference.indexing_utils import load_index
from .inference.word_selector import (
    get_interesting_words,
    get_word_trend
)
from .inference.peak_detector import detect_peaks
from .inference.doc_retriever import (
    load_length_stats,
    get_yearly_counts_for_word,
    get_all_documents_for_word_year,
    deduplicate_docs,
    extract_snippet,
    highlight,
    get_docs_by_ids,
)

# === LLM components ===
from .llm_utils.label_generator import label_topic_temporal, get_topic_labels
from .llm_utils.token_utils import (
    get_token_limit_for_model,
    count_tokens,
    estimate_avg_tokens_per_doc,
    estimate_max_k,
    estimate_max_k_fast
)
from .llm_utils.summarizer import (
    summarize_docs,
    summarize_multiword_docs,
    ask_multiturn_followup
)
from .llm.llm_router import (
    list_supported_models,
    get_llm
)

# === Dataset utilities ===
from .datasets import dynamic_dataset
from .datasets import preprocess
from .datasets.utils import logger, _utils
from .datasets.data import file_utils, download

# === Evaluation ===
from .evaluation.CoherenceModel_ttc import CoherenceModel_ttc
from .evaluation.eval import TopicQualityAssessor

# === Models ===
from .models.DETM import DETM
from .models.DTM_trainer import DTMTrainer
from .models.CFDTM.CFDTM import CFDTM
from .models.dynamic_trainer import DynamicTrainer

__all__ = [
    # Inference
    "load_beta_matrix", "load_time_labels", "get_top_words_at_time", "get_top_words_over_time",
    "load_index", "get_interesting_words", "get_word_trend", "detect_peaks",
    "load_length_stats", "get_yearly_counts_for_word", "get_all_documents_for_word_year",
    "deduplicate_docs", "extract_snippet", "highlight", "get_docs_by_ids",

    # LLM
    "summarize_docs", "summarize_multiword_docs", "ask_multiturn_followup",
    "get_token_limit_for_model", "list_supported_models", "get_llm",
    "label_topic_temporal", "get_topic_labels", "count_tokens",
    "estimate_avg_tokens_per_doc", "estimate_max_k", "estimate_max_k_fast",

    # Dataset
    "dynamic_dataset", "preprocess", "logger", "_utils", "file_utils", "download",

    # Evaluation
    "CoherenceModel_ttc", "TopicQualityAssessor",

    # Models
    "DETM", "DTMTrainer", "CFDTM", "DynamicTrainer"
]
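Because the package __init__ re-exports the main entry points, downstream code can import them from backend directly rather than from individual submodules. A small sketch, assuming the data files referenced in this commit are present and using the call signatures seen in ui_updated.py:

# Top-level imports made available by backend/__init__.py
from backend import load_beta_matrix, load_time_labels, get_top_words_over_time

beta, vocab = load_beta_matrix("data/ACL_Anthology/DTM/beta.npy",
                               "data/ACL_Anthology/processed/vocab.txt")
time_labels = load_time_labels("data/ACL_Anthology/processed/time2id.txt")
top_words = get_top_words_over_time(beta=beta, vocab=vocab, topic_id=0, top_n=10)
print(top_words)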
backend/datasets/_preprocess.py
ADDED
@@ -0,0 +1,447 @@
import os
import re
import string
import gensim.downloader
from collections import Counter
import numpy as np
import scipy.sparse
from tqdm import tqdm
from sklearn.feature_extraction.text import CountVectorizer

from backend.datasets.data import file_utils
from backend.datasets.utils._utils import get_stopwords_set
from backend.datasets.utils.logger import Logger
import json
import nltk
from nltk.stem import WordNetLemmatizer

logger = Logger("WARNING")

try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet', quiet=True)
try:
    nltk.data.find('corpora/omw-1.4')
except LookupError:
    nltk.download('omw-1.4', quiet=True)

# compile some regexes
punct_chars = list(set(string.punctuation) - set("'"))
punct_chars.sort()
punctuation = ''.join(punct_chars)
replace = re.compile('[%s]' % re.escape(punctuation))
alpha = re.compile('^[a-zA-Z_]+$')
alpha_or_num = re.compile('^[a-zA-Z_]+|[0-9_]+$')
alphanum = re.compile('^[a-zA-Z0-9_]+$')


class Tokenizer:
    def __init__(self,
                 stopwords="English",
                 keep_num=False,
                 keep_alphanum=False,
                 strip_html=False,
                 no_lower=False,
                 min_length=3,
                 lemmatize=True,
                 ):
        self.keep_num = keep_num
        self.keep_alphanum = keep_alphanum
        self.strip_html = strip_html
        self.lower = not no_lower
        self.min_length = min_length

        self.stopword_set = get_stopwords_set(stopwords)

        self.lemmatize = lemmatize
        if lemmatize:
            self.lemmatizer = WordNetLemmatizer()

    def clean_text(self, text, strip_html=False, lower=True, keep_emails=False, keep_at_mentions=False):
        # remove html tags
        if strip_html:
            text = re.sub(r'<[^>]+>', '', text)
        else:
            # replace angle brackets
            text = re.sub(r'<', '(', text)
            text = re.sub(r'>', ')', text)
        # lower case
        if lower:
            text = text.lower()
        # eliminate email addresses
        if not keep_emails:
            text = re.sub(r'\S+@\S+', ' ', text)
        # eliminate @mentions
        if not keep_at_mentions:
            text = re.sub(r'\s@\S+', ' ', text)
        # replace underscores with spaces
        text = re.sub(r'_', ' ', text)
        # break off single quotes at the ends of words
        text = re.sub(r'\s\'', ' ', text)
        text = re.sub(r'\'\s', ' ', text)
        # remove periods
        text = re.sub(r'\.', '', text)
        # replace all other punctuation (except single quotes) with spaces
        text = replace.sub(' ', text)
        # remove single quotes
        text = re.sub(r'\'', '', text)
        # replace all whitespace with a single space
        text = re.sub(r'\s', ' ', text)
        # strip off spaces on either end
        text = text.strip()
        return text

    def tokenize(self, text):
        text = self.clean_text(text, self.strip_html, self.lower)
        tokens = text.split()

        tokens = ['_' if t in self.stopword_set else t for t in tokens]

        # remove tokens that contain numbers
        if not self.keep_alphanum and not self.keep_num:
            tokens = [t if alpha.match(t) else '_' for t in tokens]

        # or just remove tokens that contain a combination of letters and numbers
        elif not self.keep_alphanum:
            tokens = [t if alpha_or_num.match(t) else '_' for t in tokens]

        # drop short tokens
        if self.min_length > 0:
            tokens = [t if len(t) >= self.min_length else '_' for t in tokens]

        if getattr(self, "lemmatize", False):
            tokens = [self.lemmatizer.lemmatize(t) if t != '_' else t for t in tokens]

        unigrams = [t for t in tokens if t != '_']
        return unigrams


def make_word_embeddings(vocab):
    glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')
    word_embeddings = np.zeros((len(vocab), glove_vectors.vectors.shape[1]))

    num_found = 0

    try:
        key_word_list = glove_vectors.index_to_key
    except:
        key_word_list = glove_vectors.index2word

    for i, word in enumerate(tqdm(vocab, desc="loading word embeddings")):
        if word in key_word_list:
            word_embeddings[i] = glove_vectors[word]
            num_found += 1

    logger.info(f'number of found embeddings: {num_found}/{len(vocab)}')

    return scipy.sparse.csr_matrix(word_embeddings)


class Preprocess:
    def __init__(self,
                 tokenizer=None,
                 test_sample_size=None,
                 test_p=0.2,
                 stopwords="English",
                 min_doc_count=0,
                 max_doc_freq=1.0,
                 keep_num=False,
                 keep_alphanum=False,
                 strip_html=False,
                 no_lower=False,
                 min_length=3,
                 min_term=0,
                 vocab_size=None,
                 seed=42,
                 verbose=True,
                 lemmatize=True,
                 ):
        """
        Args:
            test_sample_size:
                Size of the test set.
            test_p:
                Proportion of the test set. This helps sample the train set based on the size of the test set.
            stopwords:
                List of stopwords to exclude.
            min_doc_count:
                Exclude words that occur in fewer than this number of documents.
            max_doc_freq:
                Exclude words that occur in more than this proportion of documents.
            keep_num:
                Keep tokens made of only numbers.
            keep_alphanum:
                Keep tokens made of a mixture of letters and numbers.
            strip_html:
                Strip HTML tags.
            no_lower:
                Do not lowercase the text.
            min_length:
                Minimum token length.
            min_term:
                Minimum number of terms per document.
            vocab_size:
                Size of the vocabulary (by most common in the union of train and test sets, following the above exclusions).
            seed:
                Random integer seed (only relevant for choosing the test set).
            lemmatize:
                Whether to apply lemmatization to the tokens.
        """

        self.test_sample_size = test_sample_size
        self.min_doc_count = min_doc_count
        self.max_doc_freq = max_doc_freq
        self.min_term = min_term
        self.test_p = test_p
        self.vocab_size = vocab_size
        self.seed = seed

        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = Tokenizer(
                stopwords,
                keep_num,
                keep_alphanum,
                strip_html,
                no_lower,
                min_length,
                lemmatize=lemmatize
            ).tokenize

        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def parse(self, texts, vocab):
        if not isinstance(texts, list):
            texts = [texts]

        vocab_set = set(vocab)
        parsed_texts = list()
        for i, text in enumerate(tqdm(texts, desc="parsing texts")):
            tokens = self.tokenizer(text)
            tokens = [t for t in tokens if t in vocab_set]
            parsed_texts.append(" ".join(tokens))

        vectorizer = CountVectorizer(vocabulary=vocab, tokenizer=lambda x: x.split())
        sparse_bow = vectorizer.fit_transform(parsed_texts)
        return parsed_texts, sparse_bow

    def preprocess_jsonlist(self, dataset_dir, label_name=None, use_partition=True):
        if use_partition:
            train_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'train.jsonlist'))
            test_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'test.jsonlist'))
        else:
            raw_path = os.path.join(dataset_dir, 'docs.jsonl')
            with open(raw_path, 'r', encoding='utf-8') as f:
                train_items = [json.loads(line.strip()) for line in f if line.strip()]
            test_items = []

        logger.info(f"Found {len(train_items)} training documents and {len(test_items)} testing documents")

        # Initialize containers
        raw_train_texts, train_labels, raw_train_times = [], [], []
        raw_test_texts, test_labels, raw_test_times = [], [], []

        # Process train items
        for item in train_items:
            raw_train_texts.append(item['text'])
            raw_train_times.append(str(item['timestamp']))
            if label_name and label_name in item:
                train_labels.append(item[label_name])

        # Process test items
        for item in test_items:
            raw_test_texts.append(item['text'])
            raw_test_times.append(str(item['timestamp']))
            if label_name and label_name in item:
                test_labels.append(item[label_name])

        # Create and apply time2id mapping
        all_times = sorted(set(raw_train_times + raw_test_times))
        time2id = {t: i for i, t in enumerate(all_times)}

        train_times = np.array([time2id[t] for t in raw_train_times], dtype=np.int32)
        test_times = np.array([time2id[t] for t in raw_test_times], dtype=np.int32) if raw_test_times else None

        # Preprocess and get sample indices
        rst = self.preprocess(raw_train_texts, train_labels, raw_test_texts, test_labels)
        train_idx = rst.get("train_idx")
        test_idx = rst.get("test_idx")

        # Add filtered timestamps to result for saving later
        rst["train_times"] = train_times[train_idx]
        if test_times is not None and test_idx is not None:
            rst["test_times"] = test_times[test_idx]

        # Add time2id to result dict
        rst["time2id"] = time2id

        return rst

    def convert_labels(self, train_labels, test_labels):
        if train_labels:
            label_list = list(set(train_labels).union(set(test_labels)))
            label_list.sort()
            n_labels = len(label_list)
            label2id = dict(zip(label_list, range(n_labels)))

            logger.info(f"label2id: {label2id}")

            train_labels = [label2id[label] for label in train_labels]

            if test_labels:
                test_labels = [label2id[label] for label in test_labels]

        return train_labels, test_labels

    def preprocess(
        self,
        raw_train_texts,
        train_labels=None,
        raw_test_texts=None,
        test_labels=None,
        pretrained_WE=True
    ):
        np.random.seed(self.seed)

        train_texts = list()
        test_texts = list()
        word_counts = Counter()
        doc_counts_counter = Counter()

        train_labels, test_labels = self.convert_labels(train_labels, test_labels)

        for text in tqdm(raw_train_texts, desc="loading train texts"):
            tokens = self.tokenizer(text)
            word_counts.update(tokens)
            doc_counts_counter.update(set(tokens))
            parsed_text = ' '.join(tokens)
            train_texts.append(parsed_text)

        if raw_test_texts:
            for text in tqdm(raw_test_texts, desc="loading test texts"):
                tokens = self.tokenizer(text)
                word_counts.update(tokens)
                doc_counts_counter.update(set(tokens))
                parsed_text = ' '.join(tokens)
                test_texts.append(parsed_text)

        words, doc_counts = zip(*doc_counts_counter.most_common())
        doc_freqs = np.array(doc_counts) / float(len(train_texts) + len(test_texts))

        vocab = [word for i, word in enumerate(words) if doc_counts[i] >= self.min_doc_count and doc_freqs[i] <= self.max_doc_freq]

        # filter vocabulary
        if self.vocab_size is not None:
            vocab = vocab[:self.vocab_size]

        vocab.sort()

        train_idx = [i for i, text in enumerate(train_texts) if len(text.split()) >= self.min_term]
        train_idx = np.asarray(train_idx)

        if raw_test_texts is not None:
            test_idx = [i for i, text in enumerate(test_texts) if len(text.split()) >= self.min_term]
            test_idx = np.asarray(test_idx)
        else:
            test_idx = None

        # randomly sample
        if self.test_sample_size and raw_test_texts is not None:
            logger.info("sample train and test sets...")

            train_num = len(train_idx)
            test_num = len(test_idx)
            test_sample_size = min(test_num, self.test_sample_size)
            train_sample_size = int((test_sample_size / self.test_p) * (1 - self.test_p))
            if train_sample_size > train_num:
                test_sample_size = int((train_num / (1 - self.test_p)) * self.test_p)
                train_sample_size = train_num

            train_idx = train_idx[np.sort(np.random.choice(train_num, train_sample_size, replace=False))]
            test_idx = test_idx[np.sort(np.random.choice(test_num, test_sample_size, replace=False))]

            logger.info(f"sampled train size: {len(train_idx)}")
            logger.info(f"sampled test size: {len(test_idx)}")

        train_texts, train_bow = self.parse([train_texts[i] for i in train_idx], vocab)

        rst = {
            'vocab': vocab,
            'train_bow': train_bow,
            "train_texts": train_texts,
            "train_idx": train_idx,  # <--- NEW: indices of kept train samples
        }

        if train_labels:
            rst['train_labels'] = np.asarray(train_labels)[train_idx]

        logger.info(f"Real vocab size: {len(vocab)}")
        logger.info(f"Real training size: {len(train_texts)} \t avg length: {rst['train_bow'].sum() / len(train_texts):.3f}")

        if raw_test_texts:
            rst['test_texts'], rst['test_bow'] = self.parse(np.asarray(test_texts)[test_idx].tolist(), vocab)
            rst["test_idx"] = test_idx  # <--- NEW: indices of kept test samples

            if test_labels:
                rst['test_labels'] = np.asarray(test_labels)[test_idx]

            logger.info(f"Real testing size: {len(rst['test_texts'])} \t avg length: {rst['test_bow'].sum() / len(rst['test_texts']):.3f}")

        if pretrained_WE:
            rst['word_embeddings'] = make_word_embeddings(vocab)

        return rst

    def save(
        self,
        output_dir,
        vocab,
        train_texts,
        train_bow,
        word_embeddings=None,
        train_labels=None,
        test_texts=None,
        test_bow=None,
        test_labels=None,
        train_times=None,
        test_times=None,
        time2id=None  # <-- new parameter
    ):
        file_utils.make_dir(output_dir)

        file_utils.save_text(vocab, f"{output_dir}/vocab.txt")
        file_utils.save_text(train_texts, f"{output_dir}/train_texts.txt")
        scipy.sparse.save_npz(f"{output_dir}/train_bow.npz", scipy.sparse.csr_matrix(train_bow))

        if word_embeddings is not None:
            scipy.sparse.save_npz(f"{output_dir}/word_embeddings.npz", word_embeddings)

        if train_labels:
            np.savetxt(f"{output_dir}/train_labels.txt", train_labels, fmt='%i')

        if train_times is not None:
            np.savetxt(f"{output_dir}/train_times.txt", train_times, fmt='%i')

        if test_bow is not None:
            scipy.sparse.save_npz(f"{output_dir}/test_bow.npz", scipy.sparse.csr_matrix(test_bow))

        if test_texts is not None:
            file_utils.save_text(test_texts, f"{output_dir}/test_texts.txt")

        if test_labels:
            np.savetxt(f"{output_dir}/test_labels.txt", test_labels, fmt='%i')

        if test_times is not None:
            np.savetxt(f"{output_dir}/test_times.txt", test_times, fmt='%i')

        # Save time2id mapping if provided
        if time2id is not None:
            with open(f"{output_dir}/time2id.txt", "w", encoding="utf-8") as f:
                json.dump(time2id, f, indent=2)
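A sketch of how this preprocessing class is meant to be driven, based on the call signatures above. The output directory and parameter values are illustrative; docs.jsonl records are assumed to carry at least 'text' and 'timestamp' fields, as read by preprocess_jsonlist.

from backend.datasets._preprocess import Preprocess

# Tokenize, lemmatize, build the vocabulary/BoW, and load GloVe embeddings.
preprocessor = Preprocess(vocab_size=10000, min_doc_count=5, lemmatize=True)

# With use_partition=False, reads data/ACL_Anthology/docs.jsonl (no train/test split).
rst = preprocessor.preprocess_jsonlist("data/ACL_Anthology", use_partition=False)

# Persist everything the dynamic dataset loader later reads
# (bow, texts, times, vocab, embeddings, time2id).
preprocessor.save(
    "data/ACL_Anthology/processed",
    rst["vocab"],
    rst["train_texts"],
    rst["train_bow"],
    word_embeddings=rst.get("word_embeddings"),
    train_times=rst.get("train_times"),
    time2id=rst.get("time2id"),
)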
backend/datasets/data/download.py
ADDED
@@ -0,0 +1,32 @@
import os
import zipfile
from torchvision.datasets.utils import download_url
from backend.datasets.utils.logger import Logger


logger = Logger("WARNING")


def download_dataset(dataset_name, cache_path="~/.topmost"):
    cache_path = os.path.expanduser(cache_path)
    raw_filename = f'{dataset_name}.zip'

    if dataset_name in ['Wikitext-103']:
        # download from Git LFS.
        zipped_dataset_url = f"https://media.githubusercontent.com/media/BobXWu/TopMost/main/data/{raw_filename}"
    else:
        zipped_dataset_url = f"https://raw.githubusercontent.com/BobXWu/TopMost/master/data/{raw_filename}"

    logger.info(zipped_dataset_url)

    download_url(zipped_dataset_url, root=cache_path, filename=raw_filename, md5=None)

    path = f'{cache_path}/{raw_filename}'
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(cache_path)

    os.remove(path)


if __name__ == '__main__':
    download_dataset('20NG')
backend/datasets/data/file_utils.py
ADDED
@@ -0,0 +1,39 @@
import os
import json


def make_dir(path):
    os.makedirs(path, exist_ok=True)


def read_text(path):
    texts = list()
    with open(path, 'r', encoding='utf-8', errors='ignore') as file:
        for line in file:
            texts.append(line.strip())
    return texts


def save_text(texts, path):
    with open(path, 'w', encoding='utf-8') as file:
        for text in texts:
            file.write(text.strip() + '\n')


def read_jsonlist(path):
    data = list()
    with open(path, 'r', encoding='utf-8') as input_file:
        for line in input_file:
            data.append(json.loads(line))
    return data


def save_jsonlist(list_of_json_objects, path, sort_keys=True):
    with open(path, 'w', encoding='utf-8') as output_file:
        for obj in list_of_json_objects:
            output_file.write(json.dumps(obj, sort_keys=sort_keys) + '\n')


def split_text_word(texts):
    texts = [text.split() for text in texts]
    return texts
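These helpers read and write one JSON object per line. A tiny usage sketch with a hypothetical record shaped like the fields the rest of the pipeline reads ('id', 'text', 'timestamp'):

from backend.datasets.data import file_utils

docs = [{"id": 0, "text": "neural topic models over time", "timestamp": "2020"}]
file_utils.save_jsonlist(docs, "docs.jsonl")   # writes one JSON object per line
print(file_utils.read_jsonlist("docs.jsonl"))  # round-trips the records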
backend/datasets/dynamic_dataset.py
ADDED
@@ -0,0 +1,90 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.sparse
import scipy.io
from backend.datasets.data import file_utils


class _SequentialDataset(Dataset):
    def __init__(self, bow, times, time_wordfreq):
        super().__init__()
        self.bow = bow
        self.times = times
        self.time_wordfreq = time_wordfreq

    def __len__(self):
        return len(self.bow)

    def __getitem__(self, index):
        return_dict = {
            'bow': self.bow[index],
            'times': self.times[index],
            'time_wordfreq': self.time_wordfreq[self.times[index]],
        }

        return return_dict


class DynamicDataset:
    def __init__(self, dataset_dir, batch_size=200, read_labels=False, use_partition=False, device='cuda', as_tensor=True):

        self.load_data(dataset_dir, read_labels, use_partition)

        self.vocab_size = len(self.vocab)
        self.train_size = len(self.train_bow)
        self.num_times = int(self.train_times.max()) + 1  # assuming train_times is a numpy array
        self.train_time_wordfreq = self.get_time_wordfreq(self.train_bow, self.train_times)

        print('train size: ', len(self.train_bow))
        if use_partition:
            print('test size: ', len(self.test_bow))
        print('vocab size: ', len(self.vocab))
        print('average length: {:.3f}'.format(self.train_bow.sum(1).mean().item()))
        print('num of each time slice: ', self.num_times, np.bincount(self.train_times))

        if as_tensor:
            self.train_bow = torch.from_numpy(self.train_bow).float().to(device)
            self.train_times = torch.from_numpy(self.train_times).long().to(device)
            self.train_time_wordfreq = torch.from_numpy(self.train_time_wordfreq).float().to(device)

            if use_partition:
                self.test_bow = torch.from_numpy(self.test_bow).float().to(device)
                self.test_times = torch.from_numpy(self.test_times).long().to(device)

        self.train_dataset = _SequentialDataset(self.train_bow, self.train_times, self.train_time_wordfreq)

        if use_partition:
            self.test_dataset = _SequentialDataset(self.test_bow, self.test_times, self.train_time_wordfreq)

        self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

    def load_data(self, path, read_labels, use_partition=False):
        self.train_bow = scipy.sparse.load_npz(f'{path}/train_bow.npz').toarray().astype('float32')
        self.train_texts = file_utils.read_text(f'{path}/train_texts.txt')
        self.train_times = np.loadtxt(f'{path}/train_times.txt').astype('int32')
        self.vocab = file_utils.read_text(f'{path}/vocab.txt')
        self.word_embeddings = scipy.sparse.load_npz(f'{path}/word_embeddings.npz').toarray().astype('float32')

        self.pretrained_WE = self.word_embeddings  # preserve compatibility

(The remaining lines of this 90-line file are not shown in this truncated view.)
|
| 70 |
+
|
| 71 |
+
if read_labels:
|
| 72 |
+
self.train_labels = np.loadtxt(f'{path}/train_labels.txt').astype('int32')
|
| 73 |
+
|
| 74 |
+
if use_partition:
|
| 75 |
+
self.test_bow = scipy.sparse.load_npz(f'{path}/test_bow.npz').toarray().astype('float32')
|
| 76 |
+
self.test_texts = file_utils.read_text(f'{path}/test_texts.txt')
|
| 77 |
+
self.test_times = np.loadtxt(f'{path}/test_times.txt').astype('int32')
|
| 78 |
+
if read_labels:
|
| 79 |
+
self.test_labels = np.loadtxt(f'{path}/test_labels.txt').astype('int32')
|
| 80 |
+
|
| 81 |
+
# word frequency at each time slice.
|
| 82 |
+
def get_time_wordfreq(self, bow, times):
|
| 83 |
+
train_time_wordfreq = np.zeros((self.num_times, self.vocab_size))
|
| 84 |
+
for time in range(self.num_times):
|
| 85 |
+
idx = np.where(times == time)[0]
|
| 86 |
+
train_time_wordfreq[time] += bow[idx].sum(0)
|
| 87 |
+
cnt_times = np.bincount(times)
|
| 88 |
+
|
| 89 |
+
train_time_wordfreq = train_time_wordfreq / cnt_times[:, np.newaxis]
|
| 90 |
+
return train_time_wordfreq
|
backend/datasets/preprocess.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import tempfile
|
| 6 |
+
import gensim.downloader
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
import scipy.sparse
|
| 10 |
+
from gensim.models.phrases import Phrases, Phraser
|
| 11 |
+
from typing import List, Union
|
| 12 |
+
from octis.preprocessing.preprocessing import Preprocessing
|
| 13 |
+
|
| 14 |
+
logger = Logger("WARNING")
|
| 15 |
+
|
| 16 |
+
class Preprocessor:
|
| 17 |
+
def __init__(self,
|
| 18 |
+
docs_jsonl_path: str,
|
| 19 |
+
output_folder: str,
|
| 20 |
+
use_partition: bool = False,
|
| 21 |
+
use_bigrams: bool = False,
|
| 22 |
+
min_count_bigram: int = 5,
|
| 23 |
+
threshold_bigram: int = 10,
|
| 24 |
+
remove_punctuation: bool = True,
|
| 25 |
+
lemmatize: bool = True,
|
| 26 |
+
stopword_list: Union[str, List[str]] = None,
|
| 27 |
+
min_chars: int = 3,
|
| 28 |
+
min_words_docs: int = 10,
|
| 29 |
+
min_df: Union[int, float] = 0.0,
|
| 30 |
+
max_df: Union[int, float] = 1.0,
|
| 31 |
+
max_features: int = None,
|
| 32 |
+
language: str = 'english'):
|
| 33 |
+
|
| 34 |
+
self.docs_jsonl_path = docs_jsonl_path
|
| 35 |
+
self.output_folder = output_folder
|
| 36 |
+
self.use_partition = use_partition
|
| 37 |
+
self.use_bigrams = use_bigrams
|
| 38 |
+
self.min_count_bigram = min_count_bigram
|
| 39 |
+
self.threshold_bigram = threshold_bigram
|
| 40 |
+
|
| 41 |
+
os.makedirs(self.output_folder, exist_ok=True)
|
| 42 |
+
|
| 43 |
+
self.preprocessing_params = {
|
| 44 |
+
'remove_punctuation': remove_punctuation,
|
| 45 |
+
'lemmatize': lemmatize,
|
| 46 |
+
'stopword_list': stopword_list,
|
| 47 |
+
'min_chars': min_chars,
|
| 48 |
+
'min_words_docs': min_words_docs,
|
| 49 |
+
'min_df': min_df,
|
| 50 |
+
'max_df': max_df,
|
| 51 |
+
'max_features': max_features,
|
| 52 |
+
'language': language
|
| 53 |
+
}
|
| 54 |
+
self.preprocessor_octis = Preprocessing(**self.preprocessing_params)
|
| 55 |
+
|
| 56 |
+
def _load_data_to_temp_files(self):
|
| 57 |
+
"""Loads data from JSONL and writes to temporary files for OCTIS preprocessor."""
|
| 58 |
+
raw_texts = []
|
| 59 |
+
raw_timestamps = []
|
| 60 |
+
raw_labels = []
|
| 61 |
+
has_labels = False
|
| 62 |
+
|
| 63 |
+
with open(self.docs_jsonl_path, 'r', encoding='utf-8') as f:
|
| 64 |
+
for line in f:
|
| 65 |
+
data = json.loads(line.strip())
|
| 66 |
+
# Remove newlines from text
|
| 67 |
+
clean_text = data.get('text', '').replace('\n', ' ').replace('\r', ' ')
|
| 68 |
+
clean_text = " ".join(clean_text.split())
|
| 69 |
+
raw_texts.append(clean_text)
|
| 70 |
+
raw_timestamps.append(data.get('timestamp', ''))
|
| 71 |
+
label = data.get('label', '')
|
| 72 |
+
if label:
|
| 73 |
+
has_labels = True
|
| 74 |
+
raw_labels.append(label)
|
| 75 |
+
|
| 76 |
+
# Create temporary files
|
| 77 |
+
temp_dir = tempfile.mkdtemp()
|
| 78 |
+
temp_docs_path = os.path.join(temp_dir, "temp_docs.txt")
|
| 79 |
+
temp_labels_path = None
|
| 80 |
+
|
| 81 |
+
with open(temp_docs_path, 'w', encoding='utf-8') as f_docs:
|
| 82 |
+
for text in raw_texts:
|
| 83 |
+
f_docs.write(f"{text}\n")
|
| 84 |
+
|
| 85 |
+
if has_labels:
|
| 86 |
+
temp_labels_path = os.path.join(temp_dir, "temp_labels.txt")
|
| 87 |
+
with open(temp_labels_path, 'w', encoding='utf-8') as f_labels:
|
| 88 |
+
for label in raw_labels:
|
| 89 |
+
f_labels.write(f"{label}\n")
|
| 90 |
+
|
| 91 |
+
print(f"Loaded {len(raw_texts)} raw documents and created temporary files in {temp_dir}.")
|
| 92 |
+
return raw_texts, raw_timestamps, raw_labels, temp_docs_path, temp_labels_path, temp_dir
|
| 93 |
+
|
| 94 |
+
def _make_word_embeddings(self, vocab):
|
| 95 |
+
"""
|
| 96 |
+
Generates word embeddings for the given vocabulary using GloVe.
|
| 97 |
+
For n-grams (e.g., "wordA_wordB", "wordX_wordY_wordZ" for n>=2),
|
| 98 |
+
the resultant embedding is the sum of the embeddings of its constituent
|
| 99 |
+
single words (wordA + wordB + ...).
|
| 100 |
+
"""
|
| 101 |
+
print("Loading GloVe word embeddings...")
|
| 102 |
+
glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')
|
| 103 |
+
|
| 104 |
+
# Initialize word_embeddings matrix with zeros.
|
| 105 |
+
# This ensures that words not found (single or n-gram constituents)
|
| 106 |
+
# will have a zero vector embedding.
|
| 107 |
+
word_embeddings = np.zeros((len(vocab), glove_vectors.vectors.shape[1]), dtype=np.float32)
|
| 108 |
+
|
| 109 |
+
num_found = 0
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
# Using a set for key_word_list for O(1) average time complexity lookup
|
| 113 |
+
key_word_list = set(glove_vectors.index_to_key)
|
| 114 |
+
except AttributeError: # For older gensim versions
|
| 115 |
+
key_word_list = set(glove_vectors.index2word)
|
| 116 |
+
|
| 117 |
+
print("Generating word embeddings for vocabulary (including n-grams)...")
|
| 118 |
+
for i, word in enumerate(tqdm(vocab, desc="Processing vocabulary words")):
|
| 119 |
+
if '_' in word: # Check if it's a potential n-gram (n >= 2)
|
| 120 |
+
parts = word.split('_')
|
| 121 |
+
|
| 122 |
+
# Check if *all* constituent words are present in GloVe
|
| 123 |
+
all_parts_in_glove = True
|
| 124 |
+
for part in parts:
|
| 125 |
+
if part not in key_word_list:
|
| 126 |
+
all_parts_in_glove = False
|
| 127 |
+
break # One part not found, stop checking
|
| 128 |
+
|
| 129 |
+
if all_parts_in_glove:
|
| 130 |
+
# If all parts are found, sum their embeddings
|
| 131 |
+
resultant_vector = np.zeros(glove_vectors.vectors.shape[1], dtype=np.float32)
|
| 132 |
+
for part in parts:
|
| 133 |
+
resultant_vector += glove_vectors[part]
|
| 134 |
+
|
| 135 |
+
word_embeddings[i] = resultant_vector
|
| 136 |
+
num_found += 1
|
| 137 |
+
# Else: one or more constituent words not found, embedding remains zero
|
| 138 |
+
else: # It's a single word (n=1)
|
| 139 |
+
if word in key_word_list:
|
| 140 |
+
word_embeddings[i] = glove_vectors[word]
|
| 141 |
+
num_found += 1
|
| 142 |
+
# Else: single word not found, embedding remains zero
|
| 143 |
+
|
| 144 |
+
logger.info(f'Number of found embeddings (including n-grams): {num_found}/{len(vocab)}')
|
| 145 |
+
return word_embeddings # Return as dense NumPy array
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _save_doc_length_stats(self, filepath: str, output_path: str):
|
| 149 |
+
doc_lengths = []
|
| 150 |
+
try:
|
| 151 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 152 |
+
for line in f:
|
| 153 |
+
doc = line.strip()
|
| 154 |
+
if doc:
|
| 155 |
+
doc_lengths.append(len(doc))
|
| 156 |
+
except Exception as e:
|
| 157 |
+
print(f"Error processing '{filepath}': {e}")
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
if not doc_lengths:
|
| 161 |
+
print(f"No documents found in '{filepath}'.")
|
| 162 |
+
return
|
| 163 |
+
|
| 164 |
+
stats = {
|
| 165 |
+
"avg_len": float(np.mean(doc_lengths)),
|
| 166 |
+
"std_len": float(np.std(doc_lengths)),
|
| 167 |
+
"max_len": int(np.max(doc_lengths)),
|
| 168 |
+
"min_len": int(np.min(doc_lengths)),
|
| 169 |
+
"num_docs": int(len(doc_lengths))
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 173 |
+
json.dump(stats, f, indent=4)
|
| 174 |
+
print(f"Saved document length stats to: {output_path}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def preprocess(self):
|
| 178 |
+
print("Loading data and creating temporary files for OCTIS...")
|
| 179 |
+
_, raw_timestamps, _, temp_docs_path, temp_labels_path, temp_dir = \
|
| 180 |
+
self._load_data_to_temp_files()
|
| 181 |
+
|
| 182 |
+
print("Starting OCTIS pre-processing using file paths and specified parameters...")
|
| 183 |
+
octis_dataset = self.preprocessor_octis.preprocess_dataset(
|
| 184 |
+
documents_path=temp_docs_path,
|
| 185 |
+
labels_path=temp_labels_path
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Clean up temporary files immediately
|
| 189 |
+
os.remove(temp_docs_path)
|
| 190 |
+
if temp_labels_path:
|
| 191 |
+
os.remove(temp_labels_path)
|
| 192 |
+
os.rmdir(temp_dir)
|
| 193 |
+
print(f"Temporary files in {temp_dir} cleaned up.")
|
| 194 |
+
|
| 195 |
+
# --- Proxy: Save __original_indexes and then manually load it ---
|
| 196 |
+
temp_indexes_dir = tempfile.mkdtemp()
|
| 197 |
+
temp_indexes_file = os.path.join(temp_indexes_dir, "temp_original_indexes.txt")
|
| 198 |
+
|
| 199 |
+
print(f"Saving __original_indexes to {temp_indexes_file}...")
|
| 200 |
+
octis_dataset._save_document_indexes(temp_indexes_file)
|
| 201 |
+
|
| 202 |
+
# Manually load the indexes from the file
|
| 203 |
+
original_indexes_after_octis = []
|
| 204 |
+
with open(temp_indexes_file, 'r') as f_indexes:
|
| 205 |
+
for line in f_indexes:
|
| 206 |
+
original_indexes_after_octis.append(int(line.strip())) # Read as int
|
| 207 |
+
|
| 208 |
+
# Clean up the temporary indexes file and its directory
|
| 209 |
+
os.remove(temp_indexes_file)
|
| 210 |
+
os.rmdir(temp_indexes_dir)
|
| 211 |
+
print("Temporary indexes file cleaned up.")
|
| 212 |
+
# --- End Proxy ---
|
| 213 |
+
|
| 214 |
+
# Get processed data from OCTIS Dataset object
|
| 215 |
+
processed_corpus_octis_list = octis_dataset.get_corpus() # List of list of tokens
|
| 216 |
+
processed_labels_octis = octis_dataset.get_labels() # List of labels
|
| 217 |
+
|
| 218 |
+
print("Max index in original_indexes_after_octis:", max(original_indexes_after_octis))
|
| 219 |
+
print("Length of raw_timestamps:", len(raw_timestamps))
|
| 220 |
+
|
| 221 |
+
# Filter timestamps based on documents that survived OCTIS preprocessing
|
| 222 |
+
filtered_timestamps = [raw_timestamps[i] for i in original_indexes_after_octis]
|
| 223 |
+
|
| 224 |
+
print(f"OCTIS preprocessing complete. {len(processed_corpus_octis_list)} documents remaining.")
|
| 225 |
+
|
| 226 |
+
if self.use_bigrams:
|
| 227 |
+
print("Generating bigrams with Gensim...")
|
| 228 |
+
phrases = Phrases(processed_corpus_octis_list, min_count=self.min_count_bigram, threshold=self.threshold_bigram)
|
| 229 |
+
bigram_phraser = Phraser(phrases)
|
| 230 |
+
bigrammed_corpus_list = [bigram_phraser[doc] for doc in processed_corpus_octis_list]
|
| 231 |
+
print("Bigram generation complete.")
|
| 232 |
+
else:
|
| 233 |
+
print("Skipping bigram generation as 'use_bigrams' is False.")
|
| 234 |
+
bigrammed_corpus_list = processed_corpus_octis_list # Use the original processed list
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# Convert back to list of strings for easier handling if needed later, but keep as list of lists for BOW
|
| 238 |
+
bigrammed_texts_for_file = [" ".join(doc) for doc in bigrammed_corpus_list]
|
| 239 |
+
print("Bigram generation complete.")
|
| 240 |
+
|
| 241 |
+
# Build Vocabulary from OCTIS output (after bigrams)
|
| 242 |
+
# We need a flat list of all tokens to build the vocabulary
|
| 243 |
+
all_tokens = [token for doc in bigrammed_corpus_list for token in doc]
|
| 244 |
+
vocab = sorted(list(set(all_tokens))) # Sorted unique words form the vocabulary
|
| 245 |
+
word_to_id = {word: i for i, word in enumerate(vocab)}
|
| 246 |
+
|
| 247 |
+
# Create BOW matrix manually
|
| 248 |
+
print("Creating Bag-of-Words representations...")
|
| 249 |
+
rows, cols, data = [], [], []
|
| 250 |
+
for i, doc_tokens in enumerate(bigrammed_corpus_list):
|
| 251 |
+
doc_word_counts = {}
|
| 252 |
+
for token in doc_tokens:
|
| 253 |
+
if token in word_to_id: # Ensure token is in our final vocab
|
| 254 |
+
doc_word_counts[word_to_id[token]] = doc_word_counts.get(word_to_id[token], 0) + 1
|
| 255 |
+
for col_id, count in doc_word_counts.items():
|
| 256 |
+
rows.append(i)
|
| 257 |
+
cols.append(col_id)
|
| 258 |
+
data.append(count)
|
| 259 |
+
|
| 260 |
+
# Shape is (num_documents, vocab_size)
|
| 261 |
+
bow_matrix = scipy.sparse.csc_matrix((data, (rows, cols)), shape=(len(bigrammed_corpus_list), len(vocab)))
|
| 262 |
+
print("Bag-of-Words complete.")
|
| 263 |
+
|
| 264 |
+
# Handle partitioning if required
|
| 265 |
+
if self.use_partition:
|
| 266 |
+
num_docs = len(bigrammed_corpus_list)
|
| 267 |
+
train_size = int(0.8 * num_docs)
|
| 268 |
+
|
| 269 |
+
train_texts = bigrammed_texts_for_file[:train_size]
|
| 270 |
+
train_bow_matrix = bow_matrix[:train_size]
|
| 271 |
+
train_timestamps = filtered_timestamps[:train_size]
|
| 272 |
+
train_labels = processed_labels_octis[:train_size] if processed_labels_octis else []
|
| 273 |
+
|
| 274 |
+
test_texts = bigrammed_texts_for_file[train_size:]
|
| 275 |
+
test_bow_matrix = bow_matrix[train_size:]
|
| 276 |
+
test_timestamps = filtered_timestamps[train_size:]
|
| 277 |
+
test_labels = processed_labels_octis[train_size:] if processed_labels_octis else []
|
| 278 |
+
|
| 279 |
+
else:
|
| 280 |
+
train_texts = bigrammed_texts_for_file
|
| 281 |
+
train_bow_matrix = bow_matrix
|
| 282 |
+
train_timestamps = filtered_timestamps
|
| 283 |
+
train_labels = processed_labels_octis
|
| 284 |
+
test_texts = []
|
| 285 |
+
test_timestamps = []
|
| 286 |
+
test_labels = []
|
| 287 |
+
|
| 288 |
+
# Generate word embeddings using the provided function
|
| 289 |
+
word_embeddings = self._make_word_embeddings(vocab)
|
| 290 |
+
|
| 291 |
+
# Process timestamps to 0, 1, 2...T and create time2id.txt
|
| 292 |
+
print("Processing timestamps...")
|
| 293 |
+
unique_timestamps = sorted(list(set(train_timestamps + test_timestamps)))
|
| 294 |
+
time_to_id = {timestamp: i for i, timestamp in enumerate(unique_timestamps)}
|
| 295 |
+
|
| 296 |
+
train_times_ids = [time_to_id[ts] for ts in train_timestamps]
|
| 297 |
+
test_times_ids = [time_to_id[ts] for ts in test_timestamps] if self.use_partition else []
|
| 298 |
+
print("Timestamps processed.")
|
| 299 |
+
|
| 300 |
+
# Save files
|
| 301 |
+
print(f"Saving preprocessed files to {self.output_folder}...")
|
| 302 |
+
|
| 303 |
+
# 1. vocab.txt
|
| 304 |
+
with open(os.path.join(self.output_folder, "vocab.txt"), "w", encoding="utf-8") as f:
|
| 305 |
+
for word in vocab:
|
| 306 |
+
f.write(f"{word}\n")
|
| 307 |
+
|
| 308 |
+
# 2. train_texts.txt
|
| 309 |
+
train_text_path = os.path.join(self.output_folder, "train_texts.txt")
|
| 310 |
+
with open(train_text_path, "w", encoding="utf-8") as f:
|
| 311 |
+
for doc in train_texts:
|
| 312 |
+
f.write(f"{doc}\n")
|
| 313 |
+
|
| 314 |
+
# Save document length stats
|
| 315 |
+
doc_stats_path = os.path.join(self.output_folder, "length_stats.json")
|
| 316 |
+
self._save_doc_length_stats(train_text_path, doc_stats_path)
|
| 317 |
+
|
| 318 |
+
# 3. train_bow.npz
|
| 319 |
+
scipy.sparse.save_npz(os.path.join(self.output_folder, "train_bow.npz"), train_bow_matrix)
|
| 320 |
+
|
| 321 |
+
# 4. word_embeddings.npz
|
| 322 |
+
sparse_word_embeddings = scipy.sparse.csr_matrix(word_embeddings)
|
| 323 |
+
scipy.sparse.save_npz(os.path.join(self.output_folder, "word_embeddings.npz"), sparse_word_embeddings)
|
| 324 |
+
|
| 325 |
+
# 5. train_labels.txt (if labels exist)
|
| 326 |
+
if train_labels:
|
| 327 |
+
with open(os.path.join(self.output_folder, "train_labels.txt"), "w", encoding="utf-8") as f:
|
| 328 |
+
for label in train_labels:
|
| 329 |
+
f.write(f"{label}\n")
|
| 330 |
+
|
| 331 |
+
# 6. train_times.txt
|
| 332 |
+
with open(os.path.join(self.output_folder, "train_times.txt"), "w", encoding="utf-8") as f:
|
| 333 |
+
for time_id in train_times_ids:
|
| 334 |
+
f.write(f"{time_id}\n")
|
| 335 |
+
|
| 336 |
+
# Files for test set (if use_partition=True)
|
| 337 |
+
if self.use_partition:
|
| 338 |
+
# 7. test_bow.npz
|
| 339 |
+
scipy.sparse.save_npz(os.path.join(self.output_folder, "test_bow.npz"), test_bow_matrix)
|
| 340 |
+
|
| 341 |
+
# 8. test_texts.txt
|
| 342 |
+
with open(os.path.join(self.output_folder, "test_texts.txt"), "w", encoding="utf-8") as f:
|
| 343 |
+
for doc in test_texts:
|
| 344 |
+
f.write(f"{doc}\n")
|
| 345 |
+
|
| 346 |
+
# 9. test_labels.txt (if labels exist)
|
| 347 |
+
if test_labels:
|
| 348 |
+
with open(os.path.join(self.output_folder, "test_labels.txt"), "w", encoding="utf-8") as f:
|
| 349 |
+
for label in test_labels:
|
| 350 |
+
f.write(f"{label}\n")
|
| 351 |
+
|
| 352 |
+
# 10. test_times.txt
|
| 353 |
+
with open(os.path.join(self.output_folder, "test_times.txt"), "w", encoding="utf-8") as f:
|
| 354 |
+
for time_id in test_times_ids:
|
| 355 |
+
f.write(f"{time_id}\n")
|
| 356 |
+
|
| 357 |
+
# 11. time2id.txt
|
| 358 |
+
sorted_time_to_id = OrderedDict(sorted(time_to_id.items(), key=lambda item: item[1]))
|
| 359 |
+
with open(os.path.join(self.output_folder, "time2id.txt"), "w", encoding="utf-8") as f:
|
| 360 |
+
json.dump(sorted_time_to_id, f, indent=4)
|
| 361 |
+
|
| 362 |
+
print("All files saved successfully.")
|
backend/datasets/utils/_utils.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from backend.datasets.data import file_utils
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_top_words(beta, vocab, num_top_words, verbose=False):
|
| 6 |
+
topic_str_list = list()
|
| 7 |
+
for i, topic_dist in enumerate(beta):
|
| 8 |
+
topic_words = np.array(vocab)[np.argsort(topic_dist)][:-(num_top_words + 1):-1]
|
| 9 |
+
topic_str = ' '.join(topic_words)
|
| 10 |
+
topic_str_list.append(topic_str)
|
| 11 |
+
if verbose:
|
| 12 |
+
print('Topic {}: {}'.format(i, topic_str))
|
| 13 |
+
|
| 14 |
+
return topic_str_list
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_stopwords_set(stopwords=[]):
|
| 18 |
+
from backend.datasets.data.download import download_dataset
|
| 19 |
+
|
| 20 |
+
if stopwords == 'English':
|
| 21 |
+
from gensim.parsing.preprocessing import STOPWORDS as stopwords
|
| 22 |
+
|
| 23 |
+
elif stopwords in ['mallet', 'snowball']:
|
| 24 |
+
download_dataset('stopwords', cache_path='./')
|
| 25 |
+
path = f'./stopwords/{stopwords}_stopwords.txt'
|
| 26 |
+
stopwords = file_utils.read_text(path)
|
| 27 |
+
|
| 28 |
+
stopword_set = frozenset(stopwords)
|
| 29 |
+
|
| 30 |
+
return stopword_set
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
print(list(get_stopwords_set('English'))[:10])
|
| 35 |
+
print(list(get_stopwords_set('mallet'))[:10])
|
| 36 |
+
print(list(get_stopwords_set('snowball'))[:10])
|
| 37 |
+
print(list(get_stopwords_set())[:10])
|
backend/datasets/utils/logger.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Logger:
|
| 5 |
+
def __init__(self, level):
|
| 6 |
+
self.logger = logging.getLogger('TopMost')
|
| 7 |
+
self.set_level(level)
|
| 8 |
+
self._add_handler()
|
| 9 |
+
self.logger.propagate = False
|
| 10 |
+
|
| 11 |
+
def info(self, message):
|
| 12 |
+
self.logger.info(f"{message}")
|
| 13 |
+
|
| 14 |
+
def warning(self, message):
|
| 15 |
+
self.logger.warning(f"WARNING: {message}")
|
| 16 |
+
|
| 17 |
+
def set_level(self, level):
|
| 18 |
+
levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
| 19 |
+
if level in levels:
|
| 20 |
+
self.logger.setLevel(level)
|
| 21 |
+
|
| 22 |
+
def _add_handler(self):
|
| 23 |
+
sh = logging.StreamHandler()
|
| 24 |
+
sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
|
| 25 |
+
self.logger.addHandler(sh)
|
| 26 |
+
|
| 27 |
+
# Remove duplicate handlers
|
| 28 |
+
if len(self.logger.handlers) > 1:
|
| 29 |
+
self.logger.handlers = [self.logger.handlers[0]]
|
backend/evaluation/CoherenceModel_ttc.py
ADDED
|
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import multiprocessing as mp
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from gensim import interfaces, matutils
|
| 8 |
+
from gensim import utils
|
| 9 |
+
from gensim.topic_coherence import (
|
| 10 |
+
segmentation, probability_estimation,
|
| 11 |
+
direct_confirmation_measure, indirect_confirmation_measure,
|
| 12 |
+
aggregation,
|
| 13 |
+
)
|
| 14 |
+
from gensim.topic_coherence.probability_estimation import unique_ids_from_segments
|
| 15 |
+
|
| 16 |
+
# Set up logging for this module
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# Define sets for categorizing coherence measures based on their probability estimation method
|
| 20 |
+
BOOLEAN_DOCUMENT_BASED = {'u_mass'}
|
| 21 |
+
SLIDING_WINDOW_BASED = {'c_v', 'c_uci', 'c_npmi', 'c_w2v'}
|
| 22 |
+
|
| 23 |
+
# Create a namedtuple to define the structure of a coherence measure pipeline
|
| 24 |
+
# Each pipeline consists of a segmentation (seg), probability estimation (prob),
|
| 25 |
+
# confirmation measure (conf), and aggregation (aggr) function.
|
| 26 |
+
_make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')
|
| 27 |
+
|
| 28 |
+
# Define the supported coherence measures and their respective pipeline components
|
| 29 |
+
COHERENCE_MEASURES = {
|
| 30 |
+
'u_mass': _make_pipeline(
|
| 31 |
+
segmentation.s_one_pre,
|
| 32 |
+
probability_estimation.p_boolean_document,
|
| 33 |
+
direct_confirmation_measure.log_conditional_probability,
|
| 34 |
+
aggregation.arithmetic_mean
|
| 35 |
+
),
|
| 36 |
+
'c_v': _make_pipeline(
|
| 37 |
+
segmentation.s_one_set,
|
| 38 |
+
probability_estimation.p_boolean_sliding_window,
|
| 39 |
+
indirect_confirmation_measure.cosine_similarity,
|
| 40 |
+
aggregation.arithmetic_mean
|
| 41 |
+
),
|
| 42 |
+
'c_w2v': _make_pipeline(
|
| 43 |
+
segmentation.s_one_set,
|
| 44 |
+
probability_estimation.p_word2vec,
|
| 45 |
+
indirect_confirmation_measure.word2vec_similarity,
|
| 46 |
+
aggregation.arithmetic_mean
|
| 47 |
+
),
|
| 48 |
+
'c_uci': _make_pipeline(
|
| 49 |
+
segmentation.s_one_one,
|
| 50 |
+
probability_estimation.p_boolean_sliding_window,
|
| 51 |
+
direct_confirmation_measure.log_ratio_measure,
|
| 52 |
+
aggregation.arithmetic_mean
|
| 53 |
+
),
|
| 54 |
+
'c_npmi': _make_pipeline(
|
| 55 |
+
segmentation.s_one_one,
|
| 56 |
+
probability_estimation.p_boolean_sliding_window,
|
| 57 |
+
direct_confirmation_measure.log_ratio_measure,
|
| 58 |
+
aggregation.arithmetic_mean
|
| 59 |
+
),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Define default sliding window sizes for different coherence measures
|
| 63 |
+
SLIDING_WINDOW_SIZES = {
|
| 64 |
+
'c_v': 110,
|
| 65 |
+
'c_w2v': 5,
|
| 66 |
+
'c_uci': 10,
|
| 67 |
+
'c_npmi': 10,
|
| 68 |
+
'u_mass': None # u_mass does not use a sliding window
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CoherenceModel_ttc(interfaces.TransformationABC):
|
| 73 |
+
"""Objects of this class allow for building and maintaining a model for topic coherence.
|
| 74 |
+
|
| 75 |
+
Examples
|
| 76 |
+
---------
|
| 77 |
+
One way of using this feature is through providing a trained topic model. A dictionary has to be explicitly provided
|
| 78 |
+
if the model does not contain a dictionary already
|
| 79 |
+
|
| 80 |
+
.. sourcecode:: pycon
|
| 81 |
+
|
| 82 |
+
>>> from gensim.test.utils import common_corpus, common_dictionary
|
| 83 |
+
>>> from gensim.models.ldamodel import LdaModel
|
| 84 |
+
>>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
|
| 85 |
+
>>> # from your_module import CoherenceModel_ttc # if saved in a file
|
| 86 |
+
>>>
|
| 87 |
+
>>> model = LdaModel(common_corpus, 5, common_dictionary)
|
| 88 |
+
>>>
|
| 89 |
+
>>> cm = CoherenceModel_ttc(model=model, corpus=common_corpus, coherence='u_mass')
|
| 90 |
+
>>> coherence = cm.get_coherence() # get coherence value
|
| 91 |
+
|
| 92 |
+
Another way of using this feature is through providing tokenized topics such as:
|
| 93 |
+
|
| 94 |
+
.. sourcecode:: pycon
|
| 95 |
+
|
| 96 |
+
>>> from gensim.test.utils import common_corpus, common_dictionary
|
| 97 |
+
>>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
|
| 98 |
+
>>> # from your_module import CoherenceModel_ttc # if saved in a file
|
| 99 |
+
>>> topics = [
|
| 100 |
+
... ['human', 'computer', 'system', 'interface'],
|
| 101 |
+
... ['graph', 'minors', 'trees', 'eps']
|
| 102 |
+
... ]
|
| 103 |
+
>>>
|
| 104 |
+
>>> cm = CoherenceModel_ttc(topics=topics, corpus=common_corpus, dictionary=common_dictionary, coherence='u_mass')
|
| 105 |
+
>>> coherence = cm.get_coherence() # get coherence value
|
| 106 |
+
|
| 107 |
+
"""
|
| 108 |
+
def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None,
|
| 109 |
+
window_size=None, keyed_vectors=None, coherence='c_v', topn=20, processes=-1):
|
| 110 |
+
"""
|
| 111 |
+
Initializes the CoherenceModel_ttc.
|
| 112 |
+
|
| 113 |
+
Parameters
|
| 114 |
+
----------
|
| 115 |
+
model : :class:`~gensim.models.basemodel.BaseTopicModel`, optional
|
| 116 |
+
Pre-trained topic model. Should be provided if `topics` is not provided.
|
| 117 |
+
Supports models that implement the `get_topics` method.
|
| 118 |
+
topics : list of list of str, optional
|
| 119 |
+
List of tokenized topics. If provided, `dictionary` must also be provided.
|
| 120 |
+
texts : list of list of str, optional
|
| 121 |
+
Tokenized texts, needed for coherence models that use sliding window based (e.g., `c_v`, `c_uci`, `c_npmi`).
|
| 122 |
+
corpus : iterable of list of (int, number), optional
|
| 123 |
+
Corpus in Bag-of-Words format.
|
| 124 |
+
dictionary : :class:`~gensim.corpora.dictionary.Dictionary`, optional
|
| 125 |
+
Gensim dictionary mapping of id word to create corpus.
|
| 126 |
+
If `model.id2word` is present and `dictionary` is None, `model.id2word` will be used.
|
| 127 |
+
window_size : int, optional
|
| 128 |
+
The size of the window to be used for coherence measures using boolean sliding window as their
|
| 129 |
+
probability estimator. For 'u_mass' this doesn't matter.
|
| 130 |
+
If None, default window sizes from `SLIDING_WINDOW_SIZES` are used.
|
| 131 |
+
keyed_vectors : :class:`~gensim.models.keyedvectors.KeyedVectors`, optional
|
| 132 |
+
Pre-trained word embeddings (e.g., Word2Vec model) for 'c_w2v' coherence.
|
| 133 |
+
coherence : {'u_mass', 'c_v', 'c_uci', 'c_npmi', 'c_w2v'}, optional
|
| 134 |
+
Coherence measure to be used.
|
| 135 |
+
'u_mass' requires `corpus` (or `texts` which will be converted to corpus).
|
| 136 |
+
'c_v', 'c_uci', 'c_npmi', 'c_w2v' require `texts`.
|
| 137 |
+
topn : int, optional
|
| 138 |
+
Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
|
| 139 |
+
processes : int, optional
|
| 140 |
+
Number of processes to use for probability estimation phase. Any value less than 1 will be interpreted as
|
| 141 |
+
`num_cpus - 1`. Defaults to -1.
|
| 142 |
+
"""
|
| 143 |
+
# Ensure either a model or explicit topics are provided
|
| 144 |
+
if model is None and topics is None:
|
| 145 |
+
raise ValueError("One of 'model' or 'topics' has to be provided.")
|
| 146 |
+
# If topics are provided, a dictionary is mandatory to convert tokens to IDs
|
| 147 |
+
elif topics is not None and dictionary is None:
|
| 148 |
+
raise ValueError("Dictionary has to be provided if 'topics' are to be used.")
|
| 149 |
+
|
| 150 |
+
self.keyed_vectors = keyed_vectors
|
| 151 |
+
# Ensure a data source (keyed_vectors, texts, or corpus) is provided for coherence calculation
|
| 152 |
+
if keyed_vectors is None and texts is None and corpus is None:
|
| 153 |
+
raise ValueError("One of 'texts', 'corpus', or 'keyed_vectors' has to be provided.")
|
| 154 |
+
|
| 155 |
+
# Determine the dictionary to use
|
| 156 |
+
if dictionary is None:
|
| 157 |
+
# If no explicit dictionary, try to use the model's dictionary
|
| 158 |
+
if isinstance(model.id2word, utils.FakeDict):
|
| 159 |
+
# If model's id2word is a FakeDict, it means no proper dictionary is associated
|
| 160 |
+
raise ValueError(
|
| 161 |
+
"The associated dictionary should be provided with the corpus or 'id2word'"
|
| 162 |
+
" for topic model should be set as the associated dictionary.")
|
| 163 |
+
else:
|
| 164 |
+
self.dictionary = model.id2word
|
| 165 |
+
else:
|
| 166 |
+
self.dictionary = dictionary
|
| 167 |
+
|
| 168 |
+
# Store coherence type and window size
|
| 169 |
+
self.coherence = coherence
|
| 170 |
+
self.window_size = window_size
|
| 171 |
+
if self.window_size is None:
|
| 172 |
+
# Use default window size if not specified
|
| 173 |
+
self.window_size = SLIDING_WINDOW_SIZES[self.coherence]
|
| 174 |
+
|
| 175 |
+
# Store texts and corpus
|
| 176 |
+
self.texts = texts
|
| 177 |
+
self.corpus = corpus
|
| 178 |
+
|
| 179 |
+
# Validate inputs based on coherence type
|
| 180 |
+
if coherence in BOOLEAN_DOCUMENT_BASED:
|
| 181 |
+
# For document-based measures (e.g., u_mass), corpus is preferred
|
| 182 |
+
if utils.is_corpus(corpus)[0]:
|
| 183 |
+
self.corpus = corpus
|
| 184 |
+
elif self.texts is not None:
|
| 185 |
+
# If texts are provided, convert them to corpus format
|
| 186 |
+
self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
|
| 187 |
+
else:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Either 'corpus' with 'dictionary' or 'texts' should "
|
| 190 |
+
"be provided for %s coherence." % coherence)
|
| 191 |
+
|
| 192 |
+
elif coherence == 'c_w2v' and keyed_vectors is not None:
|
| 193 |
+
# For c_w2v, keyed_vectors are needed
|
| 194 |
+
pass
|
| 195 |
+
elif coherence in SLIDING_WINDOW_BASED:
|
| 196 |
+
# For sliding window-based measures, texts are required
|
| 197 |
+
if self.texts is None:
|
| 198 |
+
raise ValueError("'texts' should be provided for %s coherence." % coherence)
|
| 199 |
+
else:
|
| 200 |
+
# Raise error if coherence type is not supported
|
| 201 |
+
raise ValueError("%s coherence is not currently supported." % coherence)
|
| 202 |
+
|
| 203 |
+
self._topn = topn
|
| 204 |
+
self._model = model
|
| 205 |
+
self._accumulator = None # Cached accumulator for probability estimation
|
| 206 |
+
self._topics = None # Store topics internally
|
| 207 |
+
self.topics = topics # Call the setter to initialize topics and accumulator state
|
| 208 |
+
|
| 209 |
+
# Determine the number of processes to use for parallelization
|
| 210 |
+
self.processes = processes if processes >= 1 else max(1, mp.cpu_count() - 1)
|
| 211 |
+
|
| 212 |
+
@classmethod
|
| 213 |
+
def for_models(cls, models, dictionary, topn=20, **kwargs):
|
| 214 |
+
"""
|
| 215 |
+
Initialize a CoherenceModel_ttc with estimated probabilities for all of the given models.
|
| 216 |
+
This method extracts topics from each model and then uses `for_topics`.
|
| 217 |
+
|
| 218 |
+
Parameters
|
| 219 |
+
----------
|
| 220 |
+
models : list of :class:`~gensim.models.basemodel.BaseTopicModel`
|
| 221 |
+
List of models to evaluate coherence of. Each model should implement
|
| 222 |
+
the `get_topics` method.
|
| 223 |
+
dictionary : :class:`~gensim.corpora.dictionary.Dictionary`
|
| 224 |
+
Gensim dictionary mapping of id word.
|
| 225 |
+
topn : int, optional
|
| 226 |
+
Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
|
| 227 |
+
kwargs : object
|
| 228 |
+
Additional arguments passed to the `CoherenceModel_ttc` constructor (e.g., `corpus`, `texts`, `coherence`).
|
| 229 |
+
|
| 230 |
+
Returns
|
| 231 |
+
-------
|
| 232 |
+
:class:`~gensim.models.coherencemodel.CoherenceModel`
|
| 233 |
+
CoherenceModel_ttc instance with estimated probabilities for all given models.
|
| 234 |
+
|
| 235 |
+
Example
|
| 236 |
+
-------
|
| 237 |
+
.. sourcecode:: pycon
|
| 238 |
+
|
| 239 |
+
>>> from gensim.test.utils import common_corpus, common_dictionary
|
| 240 |
+
>>> from gensim.models.ldamodel import LdaModel
|
| 241 |
+
>>> # from your_module import CoherenceModel_ttc
|
| 242 |
+
>>>
|
| 243 |
+
>>> m1 = LdaModel(common_corpus, 3, common_dictionary)
|
| 244 |
+
>>> m2 = LdaModel(common_corpus, 5, common_dictionary)
|
| 245 |
+
>>>
|
| 246 |
+
>>> cm = CoherenceModel_ttc.for_models([m1, m2], common_dictionary, corpus=common_corpus, coherence='u_mass')
|
| 247 |
+
>>> # To get coherences for each model:
|
| 248 |
+
>>> # model_coherences = cm.compare_model_topics([
|
| 249 |
+
>>> # CoherenceModel_ttc._get_topics_from_model(m1, topn=cm.topn),
|
| 250 |
+
>>> # CoherenceModel_ttc._get_topics_from_model(m2, topn=cm.topn)
|
| 251 |
+
>>> # ])
|
| 252 |
+
"""
|
| 253 |
+
# Extract top words as lists for each model's topics
|
| 254 |
+
topics = [cls.top_topics_as_word_lists(model, dictionary, topn) for model in models]
|
| 255 |
+
kwargs['dictionary'] = dictionary
|
| 256 |
+
kwargs['topn'] = topn
|
| 257 |
+
# Use for_topics to initialize the coherence model with these topics
|
| 258 |
+
return cls.for_topics(topics, **kwargs)
|
| 259 |
+
|
| 260 |
+
@staticmethod
|
| 261 |
+
def top_topics_as_word_lists(model, dictionary, topn=20):
|
| 262 |
+
"""
|
| 263 |
+
Get `topn` topics from a model as lists of words.
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
model : :class:`~gensim.models.basemodel.BaseTopicModel`
|
| 268 |
+
Pre-trained topic model.
|
| 269 |
+
dictionary : :class:`~gensim.corpora.dictionary.Dictionary`
|
| 270 |
+
Gensim dictionary mapping of id word.
|
| 271 |
+
topn : int, optional
|
| 272 |
+
Integer corresponding to the number of top words to be extracted from each topic. Defaults to 20.
|
| 273 |
+
|
| 274 |
+
Returns
|
| 275 |
+
-------
|
| 276 |
+
list of list of str
|
| 277 |
+
Top topics in list-of-list-of-words format.
|
| 278 |
+
"""
|
| 279 |
+
# Ensure id2token mapping exists in the dictionary
|
| 280 |
+
if not dictionary.id2token:
|
| 281 |
+
dictionary.id2token = {v: k for k, v in dictionary.token2id.items()}
|
| 282 |
+
|
| 283 |
+
str_topics = []
|
| 284 |
+
for topic_distribution in model.get_topics():
|
| 285 |
+
# Get the indices of the topN words based on their probabilities
|
| 286 |
+
bestn_indices = matutils.argsort(topic_distribution, topn=topn, reverse=True)
|
| 287 |
+
# Convert word IDs back to words using the dictionary
|
| 288 |
+
best_words = [dictionary.id2token[_id] for _id in bestn_indices]
|
| 289 |
+
str_topics.append(best_words)
|
| 290 |
+
return str_topics
|
| 291 |
+
|
| 292 |
+
@classmethod
|
| 293 |
+
def for_topics(cls, topics_as_topn_terms, **kwargs):
|
| 294 |
+
"""
|
| 295 |
+
Initialize a CoherenceModel_ttc with estimated probabilities for all of the given topics.
|
| 296 |
+
This is useful when you have raw topics (list of lists of words) and not a Gensim model object.
|
| 297 |
+
|
| 298 |
+
Parameters
|
| 299 |
+
----------
|
| 300 |
+
topics_as_topn_terms : list of list of str
|
| 301 |
+
Each element in the top-level list should be a list of top-N words, one per topic.
|
| 302 |
+
For example: `[['word1', 'word2'], ['word3', 'word4']]`.
|
| 303 |
+
|
| 304 |
+
Returns
|
| 305 |
+
-------
|
| 306 |
+
:class:`~gensim.models.coherencemodel.CoherenceModel`
|
| 307 |
+
CoherenceModel_ttc with estimated probabilities for the given topics.
|
| 308 |
+
"""
|
| 309 |
+
if not topics_as_topn_terms:
|
| 310 |
+
raise ValueError("len(topics_as_topn_terms) must be > 0.")
|
| 311 |
+
if any(len(topic_list) == 0 for topic_list in topics_as_topn_terms):
|
| 312 |
+
raise ValueError("Found an empty topic listing in `topics_as_topn_terms`.")
|
| 313 |
+
|
| 314 |
+
# Determine the maximum 'topn' value among the provided topics
|
| 315 |
+
# This will be used to initialize the CoherenceModel_ttc correctly for probability estimation
|
| 316 |
+
actual_topn_in_data = 0
|
| 317 |
+
for topic_list in topics_as_topn_terms:
|
| 318 |
+
for topic in topic_list:
|
| 319 |
+
actual_topn_in_data = max(actual_topn_in_data, len(topic))
|
| 320 |
+
|
| 321 |
+
# Use the provided 'topn' from kwargs, or the determined 'actual_topn_in_data',
|
| 322 |
+
# ensuring it's not greater than the actual data available.
|
| 323 |
+
# This allows for precomputing probabilities for a wider set of words if needed.
|
| 324 |
+
topn_for_prob_estimation = min(kwargs.pop('topn', actual_topn_in_data), actual_topn_in_data)
|
| 325 |
+
|
| 326 |
+
# Flatten all topics into a single "super topic" for initial probability estimation.
|
| 327 |
+
# This ensures that all words relevant to *any* topic in the comparison set
|
| 328 |
+
# are included in the accumulator.
|
| 329 |
+
super_topic = utils.flatten(topics_as_topn_terms)
|
| 330 |
+
|
| 331 |
+
logger.info(
|
| 332 |
+
"Number of relevant terms for all %d models (or topic sets): %d",
|
| 333 |
+
len(topics_as_topn_terms), len(super_topic))
|
| 334 |
+
|
| 335 |
+
# Initialize CoherenceModel_ttc with the super topic to pre-estimate probabilities
|
| 336 |
+
# for all relevant words across all models.
|
| 337 |
+
# We pass `topics=[super_topic]` and `topn=len(super_topic)` to ensure all words
|
| 338 |
+
# are considered during the probability estimation phase.
|
| 339 |
+
cm = CoherenceModel_ttc(topics=[super_topic], topn=len(super_topic), **kwargs)
|
| 340 |
+
cm.estimate_probabilities() # Perform the actual probability estimation
|
| 341 |
+
|
| 342 |
+
# After estimation, set the 'topn' back to the desired value for coherence calculation.
|
| 343 |
+
cm.topn = topn_for_prob_estimation
|
| 344 |
+
return cm
|
| 345 |
+
|
| 346 |
+
def __str__(self):
|
| 347 |
+
"""Returns a string representation of the coherence measure pipeline."""
|
| 348 |
+
return str(self.measure)
|
| 349 |
+
|
| 350 |
+
@property
|
| 351 |
+
def model(self):
|
| 352 |
+
"""
|
| 353 |
+
Get the current topic model used by the instance.
|
| 354 |
+
|
| 355 |
+
Returns
|
| 356 |
+
-------
|
| 357 |
+
:class:`~gensim.models.basemodel.BaseTopicModel`
|
| 358 |
+
The currently set topic model.
|
| 359 |
+
"""
|
| 360 |
+
return self._model
|
| 361 |
+
|
| 362 |
+
@model.setter
|
| 363 |
+
def model(self, model):
|
| 364 |
+
"""
|
| 365 |
+
Set the topic model for the instance. When a new model is set,
|
| 366 |
+
it triggers an update of the internal topics and checks if the accumulator needs recomputing.
|
| 367 |
+
|
| 368 |
+
Parameters
|
| 369 |
+
----------
|
| 370 |
+
model : :class:`~gensim.models.basemodel.BaseTopicModel`
|
| 371 |
+
The new topic model to set.
|
| 372 |
+
"""
|
| 373 |
+
self._model = model
|
| 374 |
+
if model is not None:
|
| 375 |
+
new_topics = self._get_topics() # Get topics from the new model
|
| 376 |
+
self._update_accumulator(new_topics) # Check and update accumulator if needed
|
| 377 |
+
self._topics = new_topics # Store the new topics
|
| 378 |
+
|
| 379 |
+
@property
|
| 380 |
+
def topn(self):
|
| 381 |
+
"""
|
| 382 |
+
Get the number of top words (`_topn`) used for coherence calculation.
|
| 383 |
+
|
| 384 |
+
Returns
|
| 385 |
+
-------
|
| 386 |
+
int
|
| 387 |
+
The number of top words.
|
| 388 |
+
"""
|
| 389 |
+
return self._topn
|
| 390 |
+
|
| 391 |
+
@topn.setter
|
| 392 |
+
def topn(self, topn):
|
| 393 |
+
"""
|
| 394 |
+
Set the number of top words (`_topn`) to consider for coherence calculation.
|
| 395 |
+
If the new `topn` requires more words than currently loaded topics, and a model is available,
|
| 396 |
+
it will attempt to re-extract topics from the model.
|
| 397 |
+
|
| 398 |
+
Parameters
|
| 399 |
+
----------
|
| 400 |
+
topn : int
|
| 401 |
+
The new number of top words.
|
| 402 |
+
"""
|
| 403 |
+
# Get the length of the first topic to check current topic length
|
| 404 |
+
current_topic_length = len(self._topics[0])
|
| 405 |
+
# Determine if the new 'topn' requires more words than currently available in topics
|
| 406 |
+
requires_expansion = current_topic_length < topn
|
| 407 |
+
|
| 408 |
+
if self.model is not None:
|
| 409 |
+
self._topn = topn
|
| 410 |
+
if requires_expansion:
|
| 411 |
+
# If expansion is needed and a model is available, re-extract topics from the model.
|
| 412 |
+
# This call to the setter property `self.model = self._model` effectively re-runs
|
| 413 |
+
# the logic that extracts topics and updates the accumulator based on the new `_topn`.
|
| 414 |
+
self.model = self._model
|
| 415 |
+
else:
|
| 416 |
+
# If no model is available and expansion is required, raise an error
|
| 417 |
+
if requires_expansion:
|
| 418 |
+
raise ValueError("Model unavailable and topic sizes are less than topn=%d" % topn)
|
| 419 |
+
self._topn = topn # Topics will be truncated by the `topics` getter if needed
|
| 420 |
+
|
| 421 |
+
@property
|
| 422 |
+
def measure(self):
|
| 423 |
+
"""
|
| 424 |
+
Returns the namedtuple representing the coherence pipeline functions
|
| 425 |
+
(segmentation, probability estimation, confirmation, aggregation)
|
| 426 |
+
based on the `self.coherence` type.
|
| 427 |
+
|
| 428 |
+
Returns
|
| 429 |
+
-------
|
| 430 |
+
namedtuple
|
| 431 |
+
Pipeline that contains needed functions/method for calculating coherence.
|
| 432 |
+
"""
|
| 433 |
+
return COHERENCE_MEASURES[self.coherence]
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def topics(self):
|
| 437 |
+
"""
|
| 438 |
+
Get the current topics. If the internally stored topics have more words
|
| 439 |
+
than `self._topn`, they are truncated to `self._topn` words.
|
| 440 |
+
|
| 441 |
+
Returns
|
| 442 |
+
-------
|
| 443 |
+
list of list of str
|
| 444 |
+
Topics as lists of word tokens.
|
| 445 |
+
"""
|
| 446 |
+
# If the stored topics contain more words than `_topn`, truncate them
|
| 447 |
+
if len(self._topics[0]) > self._topn:
|
| 448 |
+
return [topic[:self._topn] for topic in self._topics]
|
| 449 |
+
else:
|
| 450 |
+
return self._topics
|
| 451 |
+
|
| 452 |
+
@topics.setter
|
| 453 |
+
def topics(self, topics):
|
| 454 |
+
"""
|
| 455 |
+
Set the topics for the instance. This method converts topic words to their
|
| 456 |
+
corresponding dictionary IDs and updates the accumulator state.
|
| 457 |
+
|
| 458 |
+
Parameters
|
| 459 |
+
----------
|
| 460 |
+
topics : list of list of str or list of list of int
|
| 461 |
+
Topics, either as lists of word tokens or lists of word IDs.
|
| 462 |
+
"""
|
| 463 |
+
if topics is not None:
|
| 464 |
+
new_topics = []
|
| 465 |
+
for topic in topics:
|
| 466 |
+
# Ensure topic elements are converted to dictionary IDs (numpy array for efficiency)
|
| 467 |
+
topic_token_ids = self._ensure_elements_are_ids(topic)
|
| 468 |
+
new_topics.append(topic_token_ids)
|
| 469 |
+
|
| 470 |
+
if self.model is not None:
|
| 471 |
+
# Warn if both model and explicit topics are set, as they might be inconsistent
|
| 472 |
+
logger.warning(
|
| 473 |
+
"The currently set model '%s' may be inconsistent with the newly set topics",
|
| 474 |
+
self.model)
|
| 475 |
+
elif self.model is not None:
|
| 476 |
+
# If topics are None but a model exists, extract topics from the model
|
| 477 |
+
new_topics = self._get_topics()
|
| 478 |
+
logger.debug("Setting topics to those of the model: %s", self.model)
|
| 479 |
+
else:
|
| 480 |
+
new_topics = None
|
| 481 |
+
|
| 482 |
+
# Check if the accumulator needs to be recomputed based on the new topics
|
| 483 |
+
self._update_accumulator(new_topics)
|
| 484 |
+
self._topics = new_topics # Store the (ID-converted) topics
|
| 485 |
+
|
| 486 |
+
def _ensure_elements_are_ids(self, topic):
|
| 487 |
+
"""
|
| 488 |
+
Internal helper to ensure that topic elements are converted to dictionary IDs.
|
| 489 |
+
Handles cases where input topic might be tokens or already IDs.
|
| 490 |
+
|
| 491 |
+
Parameters
|
| 492 |
+
----------
|
| 493 |
+
topic : list of str or list of int
|
| 494 |
+
A single topic, either as a list of word tokens or word IDs.
|
| 495 |
+
|
| 496 |
+
Returns
|
| 497 |
+
-------
|
| 498 |
+
:class:`numpy.ndarray`
|
| 499 |
+
A numpy array of word IDs for the topic.
|
| 500 |
+
|
| 501 |
+
Raises
|
| 502 |
+
------
|
| 503 |
+
KeyError
|
| 504 |
+
If a token is not found in the dictionary or an ID is not a valid key in id2token.
|
| 505 |
+
"""
|
| 506 |
+
try:
|
| 507 |
+
# Try to convert tokens to IDs. This is the common case if `topic` contains strings.
|
| 508 |
+
return np.array([self.dictionary.token2id[token] for token in topic if token in self.dictionary.token2id])
|
| 509 |
+
except KeyError:
|
| 510 |
+
# If `KeyError` occurs, assume `topic` might already be a list of IDs.
|
| 511 |
+
# Attempt to convert IDs to tokens and then back to IDs, ensuring they are valid dictionary entries.
|
| 512 |
+
# This handles cases where `topic` might contain integer IDs that are not present in the dictionary.
|
| 513 |
+
try:
|
| 514 |
+
# Convert IDs to tokens (via id2token) and then tokens to IDs (via token2id)
|
| 515 |
+
# This filters out invalid IDs.
|
| 516 |
+
return np.array([self.dictionary.token2id[self.dictionary.id2token[_id]]
|
| 517 |
+
for _id in topic if _id in self.dictionary])
|
| 518 |
+
except KeyError:
|
| 519 |
+
raise ValueError("Unable to interpret topic as either a list of tokens or a list of valid IDs within the dictionary.")
|
| 520 |
+
|
| 521 |
+
def _update_accumulator(self, new_topics):
|
| 522 |
+
"""
|
| 523 |
+
Internal helper to determine if the cached `_accumulator` (probability statistics)
|
| 524 |
+
needs to be wiped and recomputed due to changes in topics.
|
| 525 |
+
"""
|
| 526 |
+
if self._relevant_ids_will_differ(new_topics):
|
| 527 |
+
logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
|
| 528 |
+
self._accumulator = None
|
| 529 |
+
|
| 530 |
+
def _relevant_ids_will_differ(self, new_topics):
|
| 531 |
+
"""
|
| 532 |
+
Internal helper to check if the set of unique word IDs relevant to the new topics
|
| 533 |
+
is different from the IDs already covered by the current accumulator.
|
| 534 |
+
|
| 535 |
+
Parameters
|
| 536 |
+
----------
|
| 537 |
+
new_topics : list of list of int
|
| 538 |
+
The new set of topics (as word IDs).
|
| 539 |
+
|
| 540 |
+
Returns
|
| 541 |
+
-------
|
| 542 |
+
bool
|
| 543 |
+
True if the relevant IDs will differ, False otherwise.
|
| 544 |
+
"""
|
| 545 |
+
if self._accumulator is None or not self._topics_differ(new_topics):
|
| 546 |
+
return False
|
| 547 |
+
|
| 548 |
+
# Get unique IDs from the segmented new topics
|
| 549 |
+
new_set = unique_ids_from_segments(self.measure.seg(new_topics))
|
| 550 |
+
# Check if the current accumulator's relevant IDs are a superset of the new set.
|
| 551 |
+
# If not, it means the new topics introduce words not covered, so the accumulator needs updating.
|
| 552 |
+
return not self._accumulator.relevant_ids.issuperset(new_set)
|
| 553 |
+
|
| 554 |
+
def _topics_differ(self, new_topics):
|
| 555 |
+
"""
|
| 556 |
+
Internal helper to check if the new topics are different from the currently stored topics.
|
| 557 |
+
|
| 558 |
+
Parameters
|
| 559 |
+
----------
|
| 560 |
+
new_topics : list of list of int
|
| 561 |
+
The new set of topics (as word IDs).
|
| 562 |
+
|
| 563 |
+
Returns
|
| 564 |
+
-------
|
| 565 |
+
bool
|
| 566 |
+
True if topics are different, False otherwise.
|
| 567 |
+
"""
|
| 568 |
+
# Compare topic arrays using numpy.array_equal for efficient comparison
|
| 569 |
+
return (new_topics is not None
|
| 570 |
+
and self._topics is not None
|
| 571 |
+
and not np.array_equal(new_topics, self._topics))
|
| 572 |
+
|
| 573 |
+
def _get_topics(self):
|
| 574 |
+
"""
|
| 575 |
+
Internal helper function to extract top words (as IDs) from a trained topic model.
|
| 576 |
+
"""
|
| 577 |
+
return self._get_topics_from_model(self.model, self.topn)
|
| 578 |
+
|
| 579 |
+
@staticmethod
|
| 580 |
+
def _get_topics_from_model(model, topn):
|
| 581 |
+
"""
|
| 582 |
+
Internal static method to extract top `topn` words (as IDs) from a trained topic model.
|
| 583 |
+
|
| 584 |
+
Parameters
|
| 585 |
+
----------
|
| 586 |
+
model : :class:`~gensim.models.basemodel.BaseTopicModel`
|
| 587 |
+
Pre-trained topic model (must implement `get_topics` method).
|
| 588 |
+
topn : int
|
| 589 |
+
Integer corresponding to the number of top words to extract.
|
| 590 |
+
|
| 591 |
+
Returns
|
| 592 |
+
-------
|
| 593 |
+
list of :class:`numpy.ndarray`
|
| 594 |
+
A list where each element is a numpy array of word IDs representing a topic's top words.
|
| 595 |
+
|
| 596 |
+
Raises
|
| 597 |
+
------
|
| 598 |
+
AttributeError
|
| 599 |
+
If the provided model does not implement a `get_topics` method.
|
| 600 |
+
"""
|
| 601 |
+
try:
|
| 602 |
+
# Iterate over the topic distributions from the model
|
| 603 |
+
# Use matutils.argsort to get the indices (word IDs) of the top `topn` words
|
| 604 |
+
return [
|
| 605 |
+
matutils.argsort(topic, topn=topn, reverse=True) for topic in
|
| 606 |
+
model.get_topics()
|
| 607 |
+
]
|
| 608 |
+
except AttributeError:
|
| 609 |
+
raise ValueError(
|
| 610 |
+
"This topic model is not currently supported. Supported topic models"
|
| 611 |
+
" should implement the `get_topics` method.")
|
| 612 |
+
|
| 613 |
+
def segment_topics(self):
|
| 614 |
+
"""
|
| 615 |
+
Segments the current topics using the segmentation function defined by the
|
| 616 |
+
chosen coherence measure (`self.measure.seg`).
|
| 617 |
+
|
| 618 |
+
Returns
|
| 619 |
+
-------
|
| 620 |
+
list of list of tuple
|
| 621 |
+
Segmented topics. The structure depends on the segmentation method (e.g., pairs of word IDs).
|
| 622 |
+
"""
|
| 623 |
+
# Apply the segmentation function from the pipeline to the current topics
|
| 624 |
+
return self.measure.seg(self.topics)
|
| 625 |
+
|
| 626 |
+
def estimate_probabilities(self, segmented_topics=None):
|
| 627 |
+
"""
|
| 628 |
+
Accumulates word occurrences and co-occurrences from texts or corpus
|
| 629 |
+
using the optimal probability estimation method for the chosen coherence metric.
|
| 630 |
+
This operation can be computationally intensive, especially for sliding window methods.
|
| 631 |
+
|
| 632 |
+
Parameters
|
| 633 |
+
----------
|
| 634 |
+
segmented_topics : list of list of tuple, optional
|
| 635 |
+
Segmented topics. If None, `self.segment_topics()` is called internally.
|
| 636 |
+
|
| 637 |
+
Returns
|
| 638 |
+
-------
|
| 639 |
+
:class:`~gensim.topic_coherence.text_analysis.CorpusAccumulator`
|
| 640 |
+
An object that holds the accumulated statistics (word frequencies, co-occurrence frequencies).
|
| 641 |
+
"""
|
| 642 |
+
if segmented_topics is None:
|
| 643 |
+
segmented_topics = self.segment_topics()
|
| 644 |
+
|
| 645 |
+
# Choose the appropriate probability estimation method based on the coherence type
|
| 646 |
+
if self.coherence in BOOLEAN_DOCUMENT_BASED:
|
| 647 |
+
self._accumulator = self.measure.prob(self.corpus, segmented_topics)
|
| 648 |
+
else:
|
| 649 |
+
kwargs = dict(
|
| 650 |
+
texts=self.texts, segmented_topics=segmented_topics,
|
| 651 |
+
dictionary=self.dictionary, window_size=self.window_size,
|
| 652 |
+
processes=self.processes)
|
| 653 |
+
if self.coherence == 'c_w2v':
|
| 654 |
+
kwargs['model'] = self.keyed_vectors # Pass keyed_vectors for word2vec based coherence
|
| 655 |
+
|
| 656 |
+
self._accumulator = self.measure.prob(**kwargs)
|
| 657 |
+
|
| 658 |
+
return self._accumulator
|
| 659 |
+
|
| 660 |
+
def get_coherence_per_topic(self, segmented_topics=None, with_std=False, with_support=False):
|
| 661 |
+
"""
|
| 662 |
+
Calculates and returns a list of coherence values, one for each topic,
|
| 663 |
+
based on the pipeline's confirmation measure.
|
| 664 |
+
|
| 665 |
+
Parameters
|
| 666 |
+
----------
|
| 667 |
+
segmented_topics : list of list of tuple, optional
|
| 668 |
+
Segmented topics. If None, `self.segment_topics()` is called internally.
|
| 669 |
+
with_std : bool, optional
|
| 670 |
+
If True, also includes the standard deviation across topic segment sets in addition
|
| 671 |
+
to the mean coherence for each topic. Defaults to False.
|
| 672 |
+
with_support : bool, optional
|
| 673 |
+
If True, also includes the "support" (number of pairwise similarity comparisons)
|
| 674 |
+
used to compute each topic's coherence. Defaults to False.
|
| 675 |
+
|
| 676 |
+
Returns
|
| 677 |
+
-------
|
| 678 |
+
list of float or list of tuple
|
| 679 |
+
A sequence of similarity measures for each topic.
|
| 680 |
+
If `with_std` or `with_support` is True, each element in the list will be a tuple
|
| 681 |
+
containing the coherence value and the requested additional statistics.
|
| 682 |
+
"""
|
| 683 |
+
measure = self.measure
|
| 684 |
+
if segmented_topics is None:
|
| 685 |
+
segmented_topics = measure.seg(self.topics)
|
| 686 |
+
|
| 687 |
+
# Ensure probabilities are estimated before calculating coherence
|
| 688 |
+
if self._accumulator is None:
|
| 689 |
+
self.estimate_probabilities(segmented_topics)
|
| 690 |
+
|
| 691 |
+
kwargs = dict(with_std=with_std, with_support=with_support)
|
| 692 |
+
if self.coherence in BOOLEAN_DOCUMENT_BASED or self.coherence == 'c_w2v':
|
| 693 |
+
# These coherence types don't require specific additional kwargs for confirmation measure
|
| 694 |
+
pass
|
| 695 |
+
elif self.coherence == 'c_v':
|
| 696 |
+
# Specific kwargs for c_v's confirmation measure (cosine_similarity)
|
| 697 |
+
kwargs['topics'] = self.topics
|
| 698 |
+
kwargs['measure'] = 'nlr' # Normalized Log Ratio
|
| 699 |
+
kwargs['gamma'] = 1
|
| 700 |
+
else:
|
| 701 |
+
# For c_uci and c_npmi, 'normalize' parameter is relevant
|
| 702 |
+
kwargs['normalize'] = (self.coherence == 'c_npmi')
|
| 703 |
+
|
| 704 |
+
return measure.conf(segmented_topics, self._accumulator, **kwargs)
|
| 705 |
+
|
| 706 |
+
def aggregate_measures(self, topic_coherences):
|
| 707 |
+
"""
|
| 708 |
+
Aggregates the individual topic coherence measures into a single overall score
|
| 709 |
+
using the pipeline's aggregation function (`self.measure.aggr`).
|
| 710 |
+
|
| 711 |
+
Parameters
|
| 712 |
+
----------
|
| 713 |
+
topic_coherences : list of float
|
| 714 |
+
List of coherence values for each topic.
|
| 715 |
+
|
| 716 |
+
Returns
|
| 717 |
+
-------
|
| 718 |
+
float
|
| 719 |
+
The aggregated coherence value (e.g., arithmetic mean).
|
| 720 |
+
"""
|
| 721 |
+
# Apply the aggregation function from the pipeline to the list of topic coherences
|
| 722 |
+
return self.measure.aggr(topic_coherences)
|
| 723 |
+
|
| 724 |
+
def get_coherence(self):
|
| 725 |
+
"""
|
| 726 |
+
Calculates and returns the overall coherence value for the entire set of topics.
|
| 727 |
+
This is the main entry point for getting a single coherence score.
|
| 728 |
+
|
| 729 |
+
Returns
|
| 730 |
+
-------
|
| 731 |
+
float
|
| 732 |
+
The aggregated coherence value.
|
| 733 |
+
"""
|
| 734 |
+
# First, get coherence values for each individual topic
|
| 735 |
+
confirmed_measures = self.get_coherence_per_topic()
|
| 736 |
+
# Then, aggregate these topic-level coherences into a single score
|
| 737 |
+
return self.aggregate_measures(confirmed_measures)
|
| 738 |
+
|
| 739 |
+
def compare_models(self, models):
|
| 740 |
+
"""
|
| 741 |
+
Compares multiple topic models by their coherence values.
|
| 742 |
+
It extracts topics from each model and then calls `compare_model_topics`.
|
| 743 |
+
|
| 744 |
+
Parameters
|
| 745 |
+
----------
|
| 746 |
+
models : list of :class:`~gensim.models.basemodel.BaseTopicModel`
|
| 747 |
+
A sequence of topic models to compare.
|
| 748 |
+
|
| 749 |
+
Returns
|
| 750 |
+
-------
|
| 751 |
+
list of (list of float, float)
|
| 752 |
+
A sequence where each element is a pair:
|
| 753 |
+
(list of average topic coherences for the model, overall model coherence).
|
| 754 |
+
"""
|
| 755 |
+
# Extract topics (as word IDs) for each model using the internal helper
|
| 756 |
+
model_topics = [self._get_topics_from_model(model, self.topn) for model in models]
|
| 757 |
+
# Delegate to compare_model_topics for the actual coherence comparison
|
| 758 |
+
return self.compare_model_topics(model_topics)
|
| 759 |
+
|
| 760 |
+
def compare_model_topics(self, model_topics):
|
| 761 |
+
"""
|
| 762 |
+
Performs coherence evaluation for each set of topics provided in `model_topics`.
|
| 763 |
+
This method is designed to be efficient by precomputing probabilities once if needed,
|
| 764 |
+
and then evaluating coherence for each set of topics.
|
| 765 |
+
|
| 766 |
+
Parameters
|
| 767 |
+
----------
|
| 768 |
+
model_topics : list of list of list of int
|
| 769 |
+
A list where each element is itself a list of topics (each topic being a list of word IDs)
|
| 770 |
+
representing a set of topics (e.g., from a single model).
|
| 771 |
+
|
| 772 |
+
Returns
|
| 773 |
+
-------
|
| 774 |
+
list of (list of float, float)
|
| 775 |
+
A sequence where each element is a pair:
|
| 776 |
+
(list of average topic coherences for the topic set, overall topic set coherence).
|
| 777 |
+
|
| 778 |
+
Notes
|
| 779 |
+
-----
|
| 780 |
+
This method uses a heuristic of evaluating coherence at various `topn` values (e.g., 20, 15, 10, 5)
|
| 781 |
+
and averaging the results for robustness, as suggested in some research.
|
| 782 |
+
"""
|
| 783 |
+
# Store original topics and topn to restore them after comparison
|
| 784 |
+
orig_topics = self._topics
|
| 785 |
+
orig_topn = self.topn
|
| 786 |
+
|
| 787 |
+
try:
|
| 788 |
+
# Perform the actual comparison
|
| 789 |
+
coherences = self._compare_model_topics(model_topics)
|
| 790 |
+
finally:
|
| 791 |
+
# Ensure original topics and topn are restored even if an error occurs
|
| 792 |
+
self.topics = orig_topics
|
| 793 |
+
self.topn = orig_topn
|
| 794 |
+
|
| 795 |
+
return coherences
|
| 796 |
+
|
| 797 |
+
def _compare_model_topics(self, model_topics):
|
| 798 |
+
"""
|
| 799 |
+
Internal helper to get average topic and model coherences across multiple sets of topics.
|
| 800 |
+
|
| 801 |
+
Parameters
|
| 802 |
+
----------
|
| 803 |
+
model_topics : list of list of list of int
|
| 804 |
+
A list where each element is a set of topics (list of lists of word IDs).
|
| 805 |
+
|
| 806 |
+
Returns
|
| 807 |
+
-------
|
| 808 |
+
list of (list of float, float)
|
| 809 |
+
A sequence of pairs:
|
| 810 |
+
(average topic coherences across different `topn` values for each topic,
|
| 811 |
+
overall model coherence averaged across different `topn` values).
|
| 812 |
+
"""
|
| 813 |
+
coherences = []
|
| 814 |
+
# Define a grid of `topn` values to evaluate coherence.
|
| 815 |
+
# This provides a more robust average coherence value.
|
| 816 |
+
# It goes from `self.topn` down to `min(self.topn - 1, 4)` in steps of -5.
|
| 817 |
+
# e.g., if self.topn is 20, grid might be [20, 15, 10, 5].
|
| 818 |
+
# The `min(self.topn - 1, 4)` ensures at least some lower values are included,
|
| 819 |
+
# but also prevents trying `topn` values that are too small or negative.
|
| 820 |
+
last_topn_value = min(self.topn - 1, 4)
|
| 821 |
+
topn_grid = list(range(self.topn, last_topn_value, -5))
|
| 822 |
+
if not topn_grid or max(topn_grid) < 1: # Ensure at least one valid topn if range is empty or too small
|
| 823 |
+
topn_grid = [max(1, min(self.topn, 5))] # Use min of self.topn and 5, ensure at least 1
|
| 824 |
+
|
| 825 |
+
for model_num, topics in enumerate(model_topics):
|
| 826 |
+
# Set the current topics for the instance to the topics of the model being evaluated
|
| 827 |
+
self.topics = topics
|
| 828 |
+
|
| 829 |
+
coherence_at_n = {} # Dictionary to store coherence results for different `topn` values
|
| 830 |
+
for n in topn_grid:
|
| 831 |
+
self.topn = n # Set the `topn` for the current evaluation round
|
| 832 |
+
topic_coherences = self.get_coherence_per_topic()
|
| 833 |
+
|
| 834 |
+
# Handle NaN values in topic coherences by imputing with the mean
|
| 835 |
+
filled_coherences = np.array(topic_coherences, dtype=float)
|
| 836 |
+
# Check for NaN values and replace them with the mean of non-NaN values.
|
| 837 |
+
# np.nanmean handles arrays with all NaNs gracefully by returning NaN.
|
| 838 |
+
if np.any(np.isnan(filled_coherences)):
|
| 839 |
+
mean_val = np.nanmean(filled_coherences)
|
| 840 |
+
if np.isnan(mean_val): # If all are NaN, mean_val will also be NaN. In this case, replace with 0 or a very small number.
|
| 841 |
+
filled_coherences[np.isnan(filled_coherences)] = 0.0 # Or another sensible default
|
| 842 |
+
else:
|
| 843 |
+
filled_coherences[np.isnan(filled_coherences)] = mean_val
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
# Store the topic-level coherences and the aggregated (overall) coherence for this `topn`
|
| 847 |
+
coherence_at_n[n] = (topic_coherences, self.aggregate_measures(filled_coherences))
|
| 848 |
+
|
| 849 |
+
# Unpack the stored coherences for different `topn` values
|
| 850 |
+
all_topic_coherences_at_n, all_avg_coherences_at_n = zip(*coherence_at_n.values())
|
| 851 |
+
|
| 852 |
+
# Calculate the average topic coherence across all `topn` values
|
| 853 |
+
# np.vstack stacks lists of topic coherences into a 2D array, then mean(0) computes mean for each topic.
|
| 854 |
+
avg_topic_coherences = np.vstack(all_topic_coherences_at_n).mean(axis=0)
|
| 855 |
+
|
| 856 |
+
# Calculate the overall model coherence by averaging the aggregated coherences from all `topn` values
|
| 857 |
+
model_coherence = np.mean(all_avg_coherences_at_n)
|
| 858 |
+
|
| 859 |
+
logging.info("Avg coherence for model %d: %.5f" % (model_num, model_coherence))
|
| 860 |
+
coherences.append((avg_topic_coherences.tolist(), model_coherence)) # Convert numpy array back to list for output
|
| 861 |
+
|
| 862 |
+
return coherences
|
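For orientation, the pipeline above is driven much like gensim's CoherenceModel (see `_compute_coherence_ttc` in eval.py below). A minimal hedged sketch; the toy topics, texts, and variable names are illustrative, not part of this diff:

from gensim.corpora.dictionary import Dictionary

texts = [["topic", "model", "word"], ["dynamic", "topic", "evolution"], ["word", "embedding", "model"]]
dictionary = Dictionary(texts)
topics = [["topic", "model"], ["word", "embedding"]]

cm = CoherenceModel_ttc(topics=topics, texts=texts, dictionary=dictionary,
                        coherence="c_npmi", topn=2)
print(cm.get_coherence_per_topic())  # one coherence value per topic
print(cm.get_coherence())            # aggregated score for the topic set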
backend/evaluation/eval.py
ADDED
|
@@ -0,0 +1,179 @@
# dynamic_topic_quality.py
import numpy as np
import pandas as pd
from gensim.corpora.dictionary import Dictionary
from gensim.models.coherencemodel import CoherenceModel
from backend.evaluation.CoherenceModel_ttc import CoherenceModel_ttc
from typing import List, Dict

class TopicQualityAssessor:
    """
    Calculates various quality metrics for dynamic topic models from in-memory data.

    This class provides methods to compute:
    - Temporal Topic Coherence (TTC)
    - Temporal Topic Smoothness (TTS)
    - Temporal Topic Quality (TTQ)
    - Yearly Topic Coherence (TC)
    - Yearly Topic Diversity (TD)
    - Yearly Topic Quality (TQ)
    """

    def __init__(self, topics: List[List[List[str]]], train_texts: List[List[str]], topn: int, coherence_type: str):
        """
        Initializes the TopicQualityAssessor with data in memory.

        Args:
            topics (List[List[List[str]]]): A nested list of topics with structure (T, K, W),
                where T is time slices, K is topics, and W is words.
            train_texts (List[List[str]]): A list of tokenized documents for the reference corpus.
            topn (int): Number of top words per topic to consider for calculations.
            coherence_type (str): The type of coherence to calculate (e.g., 'c_npmi', 'c_v').
        """
        # 1. Set texts and dictionary
        self.texts = train_texts
        self.dictionary = Dictionary(self.texts)

        # 2. Process topics
        # User provides topics as (T, K, W) -> List[timestamps][topics][words]
        # Internal representation for temporal evolution is (K, T, W)
        topics_array_T_K_W = np.array(topics, dtype=object)
        if topics_array_T_K_W.ndim != 3:
            raise ValueError(f"Input 'topics' must be a 3-dimensional list/array. Got {topics_array_T_K_W.ndim} dimensions.")
        self.total_topics = topics_array_T_K_W.transpose(1, 0, 2)  # Shape: (K, T, W)

        # 3. Get dimensions
        self.K, self.T, _ = self.total_topics.shape

        # 4. Create topic groups for smoothness calculation (pairs of topics over time)
        groups = []
        for k in range(self.K):
            time_pairs = []
            for t in range(self.T - 1):
                time_pairs.append([self.total_topics[k, t].tolist(), self.total_topics[k, t+1].tolist()])
            groups.append(time_pairs)
        self.group_topics = np.array(groups, dtype=object)

        # 5. Create yearly topics (T, K, W) for TC/TD calculation
        self.yearly_topics = self.total_topics.transpose(1, 0, 2)

        # 6. Set parameters
        self.topn = topn
        self.coherence_type = coherence_type

    def _compute_coherence(self, topics: List[List[str]]) -> List[float]:
        cm = CoherenceModel(
            topics=topics, texts=self.texts, dictionary=self.dictionary,
            coherence=self.coherence_type, topn=self.topn
        )
        return cm.get_coherence_per_topic()

    def _compute_coherence_ttc(self, topics: List[List[str]]) -> List[float]:
        cm = CoherenceModel_ttc(
            topics=topics, texts=self.texts, dictionary=self.dictionary,
            coherence=self.coherence_type, topn=self.topn
        )
        return cm.get_coherence_per_topic()

    def _topic_smoothness(self, topics: List[List[str]]) -> float:
        K = len(topics)
        if K <= 1:
            return 1.0  # Or 0.0, depending on definition. A single topic has no other topic to be dissimilar to.
        scores = []
        for i, base in enumerate(topics):
            base_set = set(base[:self.topn])
            others = [other for j, other in enumerate(topics) if j != i]
            if not others:
                return 1.0
            overlaps = [len(base_set & set(other[:self.topn])) / self.topn for other in others]
            scores.append(sum(overlaps) / len(overlaps))
        return float(sum(scores) / K)

    def get_ttq_dataframe(self) -> pd.DataFrame:
        """Computes and returns a DataFrame with detailed TTQ metrics per topic chain."""
        all_coh_scores, avg_coh_scores = [], []
        for k in range(self.K):
            coh_per_topic = self._compute_coherence_ttc(self.total_topics[k].tolist())
            all_coh_scores.append(coh_per_topic)
            avg_coh_scores.append(float(np.mean(coh_per_topic)))

        all_smooth_scores, avg_smooth_scores = [], []
        for k in range(self.K):
            pair_scores = [self._topic_smoothness(pair) for pair in self.group_topics[k]]
            all_smooth_scores.append(pair_scores)
            avg_smooth_scores.append(float(np.mean(pair_scores)))

        df = pd.DataFrame({
            'topic_idx': list(range(self.K)),
            'temporal_coherence': all_coh_scores,
            'temporal_smoothness': all_smooth_scores,
            'avg_temporal_coherence': avg_coh_scores,
            'avg_temporal_smoothness': avg_smooth_scores
        })
        df['ttq_product'] = df['avg_temporal_coherence'] * df['avg_temporal_smoothness']
        return df

    def get_tq_dataframe(self) -> pd.DataFrame:
        """Computes and returns a DataFrame with detailed TQ metrics per time slice."""
        all_coh, avg_coh, div = [], [], []
        for t in range(self.T):
            yearly_t_topics = self.yearly_topics[t].tolist()
            coh_per_topic = self._compute_coherence(yearly_t_topics)
            all_coh.append(coh_per_topic)
            avg_coh.append(float(np.mean(coh_per_topic)))
            div.append(1 - self._topic_smoothness(yearly_t_topics))

        df = pd.DataFrame({
            'year': list(range(self.T)),
            'all_coherence': all_coh,
            'avg_coherence': avg_coh,
            'diversity': div
        })
        df['tq_product'] = df['avg_coherence'] * df['diversity']
        return df

    def get_ttc_score(self) -> float:
        """Calculates the overall Temporal Topic Coherence (TTC)."""
        ttq_df = self.get_ttq_dataframe()
        return ttq_df['avg_temporal_coherence'].mean()

    def get_tts_score(self) -> float:
        """Calculates the overall Temporal Topic Smoothness (TTS)."""
        ttq_df = self.get_ttq_dataframe()
        return ttq_df['avg_temporal_smoothness'].mean()

    def get_ttq_score(self) -> float:
        """Calculates the overall Temporal Topic Quality (TTQ)."""
        ttq_df = self.get_ttq_dataframe()
        return ttq_df['ttq_product'].mean()

    def get_tc_score(self) -> float:
        """Calculates the overall yearly Topic Coherence (TC)."""
        tq_df = self.get_tq_dataframe()
        return tq_df['avg_coherence'].mean()

    def get_td_score(self) -> float:
        """Calculates the overall yearly Topic Diversity (TD)."""
        tq_df = self.get_tq_dataframe()
        return tq_df['diversity'].mean()

    def get_tq_score(self) -> float:
        """Calculates the overall yearly Topic Quality (TQ)."""
        tq_df = self.get_tq_dataframe()
        return tq_df['tq_product'].mean()

    def get_dtq_summary(self) -> Dict[str, float]:
        """
        Computes all dynamic topic quality metrics and returns them in a dictionary.
        """
        ttq_df = self.get_ttq_dataframe()
        tq_df = self.get_tq_dataframe()
        summary = {
            'TTC': ttq_df['avg_temporal_coherence'].mean(),
            'TTS': ttq_df['avg_temporal_smoothness'].mean(),
            'TTQ': ttq_df['ttq_product'].mean(),
            'TC': tq_df['avg_coherence'].mean(),
            'TD': tq_df['diversity'].mean(),
            'TQ': tq_df['tq_product'].mean()
        }
        return summary
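To make the call pattern of the class above concrete, a minimal usage sketch follows; the toy topics, texts, and the `assessor` name are assumptions for illustration only:

# Two time slices, two topics, three words per topic (toy data).
toy_topics = [
    [["neural", "network", "training"], ["corpus", "annotation", "grammar"]],        # t = 0
    [["transformer", "network", "pretraining"], ["corpus", "treebank", "grammar"]],  # t = 1
]
toy_texts = [
    ["neural", "network", "training", "corpus"],
    ["transformer", "pretraining", "network"],
    ["corpus", "annotation", "treebank", "grammar"],
]

assessor = TopicQualityAssessor(topics=toy_topics, train_texts=toy_texts,
                                topn=3, coherence_type="c_npmi")
print(assessor.get_dtq_summary())  # keys: TTC, TTS, TTQ, TC, TD, TQ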
backend/inference/doc_retriever.py
ADDED
|
@@ -0,0 +1,219 @@
import html
import json
import re
import os
from hashlib import md5

def deduplicate_docs(collected_docs):
    seen = set()
    unique_docs = []
    for doc in collected_docs:
        # Prefer unique ID if available
        key = doc.get("id", md5(doc["text"].encode()).hexdigest())
        if key not in seen:
            seen.add(key)
            unique_docs.append(doc)
    return unique_docs

def load_length_stats(length_stats_path):
    """
    Loads length statistics from a JSON file for a given model path.

    Args:
        length_stats_path (str): Path to the 'length_stats.json' file.

    Returns:
        dict: A dictionary containing document length statistics.
    """
    if not os.path.exists(length_stats_path):
        raise FileNotFoundError(f"'length_stats.json' not found at: {length_stats_path}")

    with open(length_stats_path, "r") as f:
        length_stats = json.load(f)

    return length_stats

def get_yearly_counts_for_word(index, word):
    if word not in index:
        print(f"[ERROR] Word '{word}' not found in index.")
        return [], []

    year_counts = index[word]
    sorted_items = sorted((int(year), len(doc_ids)) for year, doc_ids in year_counts.items())
    years, counts = zip(*sorted_items) if sorted_items else ([], [])
    return list(years), list(counts)


def get_all_documents_for_word_year(index, docs_file_path, word, year):
    """
    Returns all full documents (text + metadata) that contain a given word in a given year.

    Parameters:
        index (dict): Inverted index.
        docs_file_path (str): Path to original jsonl corpus.
        word (str): Word (unigram or bigram).
        year (int): Year to retrieve docs for.

    Returns:
        List[Dict]: List of documents with 'id', 'timestamp', and 'text'.
    """
    year = int(year)

    if word not in index or year not in index[word]:
        return []

    doc_ids = set(index[word][year])
    results = []

    try:
        with open(docs_file_path, 'r', encoding='utf-8') as f:
            for doc_id, line in enumerate(f):
                if doc_id in doc_ids:
                    doc = json.loads(line)
                    results.append({
                        "id": doc_id,
                        "timestamp": doc.get("timestamp", "N/A"),
                        "text": doc["text"]
                    })
    except Exception as e:
        print(f"[ERROR] Could not load documents: {e}")

    return results


def get_documents_with_all_words_for_year(index, docs_path, words, year):
    doc_sets = []
    all_doc_occurrences = {}

    for word in words:
        word_docs = get_all_documents_for_word_year(index, docs_path, word, year)
        doc_sets.append(set(doc["id"] for doc in word_docs))
        for doc in word_docs:
            all_doc_occurrences.setdefault(doc["id"], doc)

    common_doc_ids = set.intersection(*doc_sets) if doc_sets else set()
    return [all_doc_occurrences[doc_id] for doc_id in common_doc_ids]


def get_intersection_doc_counts_by_year(index, docs_path, words, all_years):
    year_counts = {}
    for y in all_years:
        docs = get_documents_with_all_words_for_year(index, docs_path, words, y)
        year_counts[y] = len(docs)
    return year_counts


def extract_snippet(text, query, window=30):
    """
    Return a short snippet around the first occurrence of the query word.
    """
    pattern = re.compile(re.escape(query.replace('_', ' ')), re.IGNORECASE)
    match = pattern.search(text)
    if not match:
        return text[:200] + "..."

    start = max(match.start() - window, 0)
    end = min(match.end() + window, len(text))
    snippet = text[start:end].strip()

    return f"...{snippet}..."

def highlight(text, query, highlight_color="#FFD54F"):
    """
    Highlight all instances of the query term in text using a colored <mark> tag.
    """
    escaped_query = re.escape(query.replace('_', ' '))
    pattern = re.compile(f"({escaped_query})", flags=re.IGNORECASE)

    def replacer(match):
        matched_text = html.escape(match.group(1))
        return f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"

    return pattern.sub(replacer, html.escape(text))

def highlight_words(text, query_words, highlight_color="#24F31D", lemma_to_forms=None):
    """
    Highlight all surface forms of each query lemma in the text using a colored <mark> tag.

    Args:
        text (str): The input raw document text.
        query_words (List[str]): Lemmatized query tokens to highlight.
        highlight_color (str): Color to use for highlighting.
        lemma_to_forms (Dict[str, Set[str]]): Maps a lemma to its surface forms.
    """
    # Escape HTML special characters first
    escaped_text = html.escape(text)

    # Expand query words to include all surface forms
    expanded_forms = set()
    for lemma in query_words:
        if lemma_to_forms and lemma in lemma_to_forms:
            expanded_forms.update(lemma_to_forms[lemma])
        else:
            expanded_forms.add(lemma)  # Fallback if map is missing

    # Sort by length to avoid partial overlaps (e.g., "run" before "running")
    sorted_queries = sorted(expanded_forms, key=lambda w: -len(w))

    for word in sorted_queries:
        # Match full word, case insensitive
        pattern = re.compile(rf'\b({re.escape(word)})\b', flags=re.IGNORECASE)

        def replacer(match):
            matched_text = match.group(1)
            return f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"

        escaped_text = pattern.sub(replacer, escaped_text)

    return escaped_text

def get_docs_by_ids(docs_file_path, doc_ids):
    """
    Efficiently retrieves specific documents from a .jsonl file by their line number (ID).

    This function reads the file line-by-line and only parses the lines that match
    the requested document IDs, avoiding loading the entire file into memory.

    Args:
        docs_file_path (str): The path to the documents.jsonl file.
        doc_ids (list or set): A collection of document IDs (0-indexed line numbers) to retrieve.

    Returns:
        list[dict]: A list of document dictionaries that were found. Each dictionary
            is augmented with an 'id' key corresponding to its line number.
    """
    # Use a set for efficient O(1) lookups.
    doc_ids_to_find = set(doc_ids)
    found_docs = {}

    if not doc_ids_to_find:
        return []

    try:
        with open(docs_file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                # If the current line number is one we're looking for
                if i in doc_ids_to_find:
                    try:
                        doc = json.loads(line)
                        # Explicitly add the line number as the 'id'
                        doc['id'] = i
                        found_docs[i] = doc
                        # Optimization: stop reading the file once all docs are found
                        if len(found_docs) == len(doc_ids_to_find):
                            break
                    except json.JSONDecodeError:
                        # Skip malformed lines but inform the user
                        print(f"[WARNING] Skipping malformed JSON on line {i+1} in {docs_file_path}")
                        continue

    except FileNotFoundError:
        print(f"[ERROR] Document file not found at: {docs_file_path}")
        return []
    except Exception as e:
        print(f"[ERROR] An unexpected error occurred while reading documents: {e}")
        return []

    # Return the documents in the same order as the original doc_ids list
    # This ensures consistency for downstream processing.
    return [found_docs[doc_id] for doc_id in doc_ids if doc_id in found_docs]
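The helpers above all assume an in-memory inverted index shaped as word -> year -> list of document IDs (the structure built in indexing_utils.py below). A hedged sketch with made-up postings:

# Hypothetical index fragment; doc IDs are 0-based line numbers in docs.jsonl.
index = {
    "topic_model": {2019: [0, 4], 2020: [2], 2021: [5, 7, 9]},
    "grammar": {2019: [1]},
}

years, counts = get_yearly_counts_for_word(index, "topic_model")
# years  -> [2019, 2020, 2021]; counts -> [2, 1, 3]

# Full matching documents for one year (assumes a local corpus such as data/ACL_Anthology/docs.jsonl).
docs_2021 = get_all_documents_for_word_year(index, "data/ACL_Anthology/docs.jsonl", "topic_model", 2021)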
backend/inference/indexing_utils.py
ADDED
|
@@ -0,0 +1,146 @@
import json
import os
import re
import spacy
from collections import defaultdict

# Load spaCy once
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

def tokenize(text):
    return re.findall(r"\b\w+\b", text.lower())

def has_bigram(tokens, bigram):
    parts = bigram.split('_')
    for i in range(len(tokens) - len(parts) + 1):
        if tokens[i:i + len(parts)] == parts:
            return True
    return False

def build_inverse_lemma_map(docs_file_path, cache_path=None):
    """
    Build or load a mapping from lemma -> set of surface forms seen in corpus.
    If cache_path is provided and exists, loads from it.
    Else builds from scratch and saves to cache_path.
    """
    if cache_path and os.path.exists(cache_path):
        print(f"[INFO] Loading cached lemma_to_forms from {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            raw_map = json.load(f)
        return {lemma: set(forms) for lemma, forms in raw_map.items()}

    print(f"[INFO] Building inverse lemma map from {docs_file_path}...")
    lemma_to_forms = defaultdict(set)

    with open(docs_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            doc = json.loads(line)
            tokens = tokenize(doc['text'])
            spacy_doc = nlp(" ".join(tokens))
            for token in spacy_doc:
                lemma_to_forms[token.lemma_].add(token.text.lower())

    if cache_path:
        print(f"[INFO] Saving lemma_to_forms to {cache_path}")
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump({k: list(v) for k, v in lemma_to_forms.items()}, f, indent=2)

    return lemma_to_forms

def build_inverted_index(docs_file_path, vocab_set, lemma_map_path=None):
    vocab_unigrams = {w for w in vocab_set if '_' not in w}
    vocab_bigrams = {w for w in vocab_set if '_' in w}

    # Load or build lemma map
    lemma_to_forms = build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path)

    index = defaultdict(lambda: defaultdict(list))
    docs = []
    global_seen_words = set()

    with open(docs_file_path, 'r', encoding='utf-8') as f:
        for doc_id, line in enumerate(f):
            doc = json.loads(line)
            text = doc['text']
            timestamp = int(doc['timestamp'])
            docs.append({"text": text, "timestamp": timestamp})

            tokens = tokenize(text)
            token_set = set(tokens)
            seen_words = set()

            # Match all lemma queries using surface forms
            for lemma in vocab_unigrams:
                surface_forms = lemma_to_forms.get(lemma, set())
                if token_set & surface_forms:
                    index[lemma][timestamp].append(doc_id)
                    seen_words.add(lemma)

            for bigram in vocab_bigrams:
                if bigram not in seen_words and has_bigram(tokens, bigram):
                    index[bigram][timestamp].append(doc_id)
                    seen_words.add(bigram)

            global_seen_words.update(seen_words)

            if (doc_id + 1) % 500 == 0:
                missing = vocab_set - global_seen_words
                print(f"[INFO] After {doc_id+1} docs, {len(missing)} vocab words still not seen.")
                print("Example missing words:", list(missing)[:5])

    missing_final = vocab_set - global_seen_words
    if missing_final:
        print(f"[WARNING] {len(missing_final)} vocab words were never found in any document.")
        print("Examples:", list(missing_final)[:10])

    return index, docs, lemma_to_forms

def save_index_to_disk(index, index_path):
    index_clean = {
        word: {str(ts): doc_ids for ts, doc_ids in ts_dict.items()}
        for word, ts_dict in index.items()
    }
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    with open(index_path, "w", encoding='utf-8') as f:
        json.dump(index_clean, f, ensure_ascii=False)

def load_index_from_disk(index_path):
    with open(index_path, 'r', encoding='utf-8') as f:
        raw_index = json.load(f)

    index = defaultdict(lambda: defaultdict(list))
    for word, ts_dict in raw_index.items():
        for ts, doc_ids in ts_dict.items():
            index[word][int(ts)] = doc_ids

    return index

def load_docs(docs_file_path):
    docs = []
    with open(docs_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            doc = json.loads(line)
            docs.append({
                "text": doc["text"],
                "timestamp": int(doc["timestamp"])
            })
    return docs

def load_index(docs_file_path, vocab, index_path=None, lemma_map_path=None):
    if index_path and os.path.exists(index_path):
        index = load_index_from_disk(index_path)
        docs = load_docs(docs_file_path)
        lemma_to_forms = build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path)
        return index, docs, lemma_to_forms

    index, docs, lemma_to_forms = build_inverted_index(
        docs_file_path,
        set(vocab),
        lemma_map_path=lemma_map_path
    )

    if index_path:
        save_index_to_disk(index, index_path)

    return index, docs, lemma_to_forms
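A rough end-to-end sketch of `load_index`; the paths follow the data layout elsewhere in this diff but are assumptions here:

docs_path = "data/ACL_Anthology/docs.jsonl"
index_path = "data/ACL_Anthology/inverted_index.json"
lemma_map_path = "data/ACL_Anthology/processed/lemma_to_forms.json"

with open("data/ACL_Anthology/processed/vocab.txt", "r", encoding="utf-8") as f:
    vocab = [line.strip() for line in f if line.strip()]

# Loads the cached index if present, otherwise builds and saves it.
index, docs, lemma_to_forms = load_index(docs_path, vocab,
                                         index_path=index_path,
                                         lemma_map_path=lemma_map_path)
print(len(docs), "documents;", len(index), "vocabulary entries with postings")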
backend/inference/peak_detector.py
ADDED
|
@@ -0,0 +1,18 @@
import numpy as np
from scipy.signal import find_peaks

def detect_peaks(trend, prominence=0.001, distance=2):
    """
    Detect peaks in a word's trend over time.

    Args:
        trend: List or np.array of floats (word importance over time)
        prominence: Required prominence of peaks (tune based on scale)
        distance: Minimum distance between peaks

    Returns:
        List of indices (timestamps) where peaks occur
    """
    trend = np.array(trend)
    peaks, _ = find_peaks(trend, prominence=prominence, distance=distance)
    return peaks.tolist()
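A quick illustrative call; the trend values are invented:

trend = [0.01, 0.02, 0.12, 0.03, 0.02, 0.09, 0.02]
print(detect_peaks(trend, prominence=0.05, distance=2))  # e.g. [2, 5]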
backend/inference/process_beta.py
ADDED
|
@@ -0,0 +1,33 @@
import numpy as np
import json

def load_beta_matrix(beta_path: str, vocab_path: str):
    """
    Loads the beta matrix (T x K x V) and vocab list.

    Returns:
        beta: np.ndarray of shape (T, K, V)
        vocab: list of words
    """
    beta = np.load(beta_path)  # shape: T x K x V
    with open(vocab_path, 'r') as f:
        vocab = [line.strip() for line in f.readlines()]
    return beta, vocab

def get_top_words_at_time(beta, vocab, topic_id, time, top_n):
    topic_beta = beta[time, topic_id, :]
    top_indices = topic_beta.argsort()[-top_n:][::-1]
    return [vocab[i] for i in top_indices]

def get_top_words_over_time(beta, vocab, topic_id, top_n):
    topic_beta = beta[:, topic_id, :]
    mean_beta = topic_beta.mean(axis=0)
    top_indices = mean_beta.argsort()[-top_n:][::-1]
    return [vocab[i] for i in top_indices]

def load_time_labels(time2id_path):
    with open(time2id_path, 'r') as f:
        time2id = json.load(f)
    # Invert and sort by id
    id2time = {v: k for k, v in time2id.items()}
    return [id2time[i] for i in sorted(id2time)]
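A brief sketch of how these loaders compose; the file paths mirror the data folders in this diff but are assumptions here:

beta, vocab = load_beta_matrix("data/ACL_Anthology/DETM/beta.npy",
                               "data/ACL_Anthology/processed/vocab.txt")
T, K, V = beta.shape
print(f"{T} time slices, {K} topics, {V} vocabulary terms")

# Top 10 words of topic 0 at the first time slice, and averaged over all slices.
print(get_top_words_at_time(beta, vocab, topic_id=0, time=0, top_n=10))
print(get_top_words_over_time(beta, vocab, topic_id=0, top_n=10))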
backend/inference/word_selector.py
ADDED
|
@@ -0,0 +1,102 @@
import numpy as np
from scipy.special import softmax

def get_interesting_words(beta, vocab, topic_id, top_k_final=10, restrict_to=None):
    """
    Suggests interesting words by prioritizing "bursty" or "emerging" terms,
    making it effective at capturing important low-probability words.

    This algorithm focuses on the ratio of a word's peak probability to its mean,
    capturing words that show significant growth or have a sudden moment of high
    relevance, even if their average probability is low.

    Parameters:
    - beta: np.ndarray (T, K, V) - Topic-word distributions for each timestamp.
    - vocab: list of V words - The vocabulary.
    - topic_id: int - The ID of the topic to analyze.
    - top_k_final: int - The number of words to return.
    - restrict_to: optional list of str - Restricts scoring to a subset of words.

    Returns:
    - list of top_k_final interesting words (strings).
    """
    T, K, V = beta.shape

    # --- 1. Detect whether softmax is needed ---
    row_sums = beta.sum(axis=2)
    is_prob_dist = np.allclose(row_sums, 1.0, atol=1e-2)

    if not is_prob_dist:
        print("🔁 Beta is not normalized — applying softmax across words per topic.")
        beta = softmax(beta / 1e-3, axis=2)

    # --- 2. Now extract normalized topic slice ---
    topic_beta = beta[:, topic_id, :]  # Shape: (T, V)

    # Mean and Peak probability within the topic for each word
    mean_topic = topic_beta.mean(axis=0)  # Shape: (V,)
    peak_topic = topic_beta.max(axis=0)   # Shape: (V,)

    # Corpus-wide mean for baseline comparison
    mean_all = beta.mean(axis=(0, 1))  # Shape: (V,)

    # Epsilon to prevent division by zero for words that never appear
    epsilon = 1e-9

    # --- 3. Calculate the three core components of the new score ---

    # a) Burstiness Score: How much a word's peak stands out from its own average.
    #    This is the key to finding "surprising" words.
    burstiness_score = peak_topic / (mean_topic + epsilon)

    # b) Peak Specificity: How much the word's peak in this topic stands out from
    #    its average presence in the entire corpus.
    peak_specificity_score = peak_topic / (mean_all + epsilon)

    # c) Uniqueness Score (same as before): Penalizes words active in many topics.
    active_in_topics = (beta > 1e-5).mean(axis=0)  # Shape: (K, V)
    idf_like = np.log((K + 1) / (active_in_topics.sum(axis=0) + 1))  # Shape: (V,)

    # --- 4. Compute Final Interestingness Score ---
    # This score is high for words that are unique, have a high peak relative
    # to their baseline, and whose peak is an unusual event for that word.
    final_scores = burstiness_score * peak_specificity_score * idf_like

    # --- 5. Rank and select top words ---
    if restrict_to is not None:
        restrict_set = set(restrict_to)
        word_indices = [i for i, w in enumerate(vocab) if w in restrict_set]
    else:
        # Use a plain list so the emptiness check below is unambiguous
        word_indices = list(range(V))

    if not word_indices:
        return []

    # Rank the filtered indices by the final score in descending order
    sorted_indices = sorted(word_indices, key=lambda i: -final_scores[i])

    return [vocab[i] for i in sorted_indices[:top_k_final]]


def get_word_trend(beta, vocab, word, topic_id):
    """
    Get the time trend of a word's probability under a specific topic.

    Args:
        beta: np.ndarray of shape (T, K, V)
        vocab: list of vocab words
        word: word to search
        topic_id: index of topic to inspect (0 <= topic_id < K)

    Returns:
        List of word probabilities over time (length T)
    """
    T, K, V = beta.shape
    if word not in vocab:
        raise ValueError(f"Word '{word}' not found in vocab.")
    if not (0 <= topic_id < K):
        raise ValueError(f"Invalid topic_id {topic_id}. Must be between 0 and {K - 1}.")

    word_index = vocab.index(word)
    trend = beta[:, topic_id, word_index]  # shape (T,)
    return trend.tolist()
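For intuition, a synthetic example: a word that spikes at one time step should outrank a uniformly frequent word even if its average probability is lower. The numbers below are invented for illustration:

import numpy as np

T, K, V = 5, 2, 4
vocab = ["steady", "bursty", "common", "rare"]

beta = np.full((T, K, V), 0.1)
beta[:, 0, 0] = 0.30   # "steady": always moderately probable in topic 0
beta[3, 0, 1] = 0.60   # "bursty": one sharp spike in topic 0
beta = beta / beta.sum(axis=2, keepdims=True)  # already a distribution, so no softmax is applied

print(get_interesting_words(beta, vocab, topic_id=0, top_k_final=2))
# "bursty" should typically rank ahead of "steady" because of its peak/mean ratio.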
backend/llm/custom_gemini.py
ADDED
|
@@ -0,0 +1,28 @@
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from typing import Any, List, Optional


class ChatGemini(BaseChatModel):
    # BaseChatModel is a pydantic model, so configuration is declared as fields.
    api_key: str
    model: str = "gemini-pro"
    temperature: float = 0.7
    client: Any = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.client = ChatGoogleGenerativeAI(
            model=self.model,
            temperature=self.temperature,
            google_api_key=self.api_key
        )

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None,
                  run_manager=None, **kwargs) -> ChatResult:
        # Convert LangChain messages to a single prompt string
        prompt = "\n".join(
            msg.content for msg in messages if isinstance(msg, (HumanMessage, AIMessage))
        )
        response = self.client.invoke(prompt)
        # Wrap the provider response in the ChatResult structure expected by BaseChatModel
        message = response if isinstance(response, AIMessage) else AIMessage(content=str(response))
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "gemini"
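A hedged usage sketch of the wrapper above; the key and model name are placeholders:

from langchain_core.messages import HumanMessage

llm = ChatGemini(api_key="YOUR_GEMINI_API_KEY", model="gemini-1.5-flash", temperature=0.2)
reply = llm.invoke([HumanMessage(content="Give a two-word label for: neural, network, training")])
print(reply.content)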
backend/llm/custom_mistral.py
ADDED
|
@@ -0,0 +1,27 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration
from typing import List, Optional
import requests
import os

class ChatMistral(BaseChatModel):
    # Declared as pydantic fields because BaseChatModel is a pydantic model.
    hf_token: Optional[str] = None
    model_url: str = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
    headers: Optional[dict] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hf_token = self.hf_token or os.getenv("HF_TOKEN")
        self.headers = {"Authorization": f"Bearer {self.hf_token}"}

    def _call(self, prompt: str) -> str:
        response = requests.post(
            self.model_url,
            headers=self.headers,
            json={"inputs": prompt, "parameters": {"max_new_tokens": 256}},
        )
        return response.json()[0]["generated_text"]

    def invoke(self, messages, **kwargs):
        prompt = "\n".join(msg.content for msg in messages if isinstance(msg, HumanMessage))
        response = self._call(prompt)
        return AIMessage(content=response)

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None,
                  run_manager=None, **kwargs) -> ChatResult:
        return ChatResult(generations=[ChatGeneration(message=self.invoke(messages))])

    @property
    def _llm_type(self) -> str:
        return "mistral-hf-inference"
backend/llm/llm_router.py
ADDED
|
@@ -0,0 +1,73 @@
| 1 |
+
from langchain_anthropic import ChatAnthropic
|
| 2 |
+
from backend.llm.custom_mistral import ChatMistral
|
| 3 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
import os
|
| 6 |
+
import google.auth.transport.requests
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
resp = requests.get("https://www.google.com", proxies={
|
| 10 |
+
"http": os.getenv("http_proxy"),
|
| 11 |
+
"https": os.getenv("https_proxy")
|
| 12 |
+
})
|
| 13 |
+
|
| 14 |
+
def list_supported_models(provider=None):
|
| 15 |
+
if provider == "OpenAI":
|
| 16 |
+
return ["gpt-4.1-nano", "gpt-4o-mini"]
|
| 17 |
+
elif provider == "Anthropic":
|
| 18 |
+
return ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
|
| 19 |
+
elif provider == "Gemini":
|
| 20 |
+
return ["gemini-2.0-flash-lite", "gemini-1.5-flash"]
|
| 21 |
+
elif provider == "Mistral":
|
| 22 |
+
return ["mistral-small", "mistral-medium"]
|
| 23 |
+
else:
|
| 24 |
+
# Default fallback: all models grouped by provider
|
| 25 |
+
return {
|
| 26 |
+
"OpenAI": ["gpt-4.1-nano", "gpt-4o-mini"],
|
| 27 |
+
"Anthropic": ["claude-3-opus-20240229", "claude-3-sonnet-20240229"],
|
| 28 |
+
"Gemini": ["gemini-2.0-flash-lite", "gemini-1.5-flash"],
|
| 29 |
+
"Mistral": ["mistral-small", "mistral-medium"]
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_llm(provider: str, model: str, api_key: str = None):
|
| 34 |
+
if provider == "OpenAI":
|
| 35 |
+
api_key = api_key or os.getenv("OPENAI_API_KEY")
|
| 36 |
+
if not api_key:
|
| 37 |
+
raise ValueError("Missing OpenAI API key.")
|
| 38 |
+
return ChatOpenAI(model_name=model, temperature=0, openai_api_key=api_key)
|
| 39 |
+
|
| 40 |
+
elif provider == "Anthropic":
|
| 41 |
+
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
| 42 |
+
if not api_key:
|
| 43 |
+
raise ValueError("Missing Anthropic API key.")
|
| 44 |
+
return ChatAnthropic(model=model, temperature=0, anthropic_api_key=api_key)
|
| 45 |
+
|
| 46 |
+
elif provider == "Gemini":
|
| 47 |
+
api_key = api_key or os.getenv("GEMINI_API_KEY")
|
| 48 |
+
if not api_key:
|
| 49 |
+
raise ValueError("Missing Gemini API key.")
|
| 50 |
+
# --- Patch: Set proxy if available ---
|
| 51 |
+
if "HTTP_PROXY" in os.environ or "http_proxy" in os.environ:
|
| 52 |
+
|
| 53 |
+
proxies = {
|
| 54 |
+
"http": os.getenv("http_proxy") or os.getenv("HTTP_PROXY"),
|
| 55 |
+
"https": os.getenv("https_proxy") or os.getenv("HTTPS_PROXY")
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
google.auth.transport.requests.requests.Request = lambda *args, **kwargs: requests.Request(
|
| 59 |
+
*args, **kwargs, proxies=proxies
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return ChatGoogleGenerativeAI(model=model, temperature=0, google_api_key=api_key)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
elif provider == "Mistral":
|
| 66 |
+
api_key = api_key or os.getenv("MISTRAL_API_KEY")
|
| 67 |
+
if not api_key:
|
| 68 |
+
raise ValueError("Missing Mistral API key.")
|
| 69 |
+
return ChatMistral(model=model, temperature=0, mistral_api_key=api_key)
|
| 70 |
+
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError(f"Unsupported provider: {provider}")
|
| 73 |
+
|
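Taken together, list_supported_models and get_llm form the provider-selection layer used by the rest of the backend. The following is an illustrative usage sketch only (not part of this commit); it assumes the matching API key is set in the environment and that network access is available, since importing the module issues a connectivity probe at import time.

from backend.llm.llm_router import list_supported_models, get_llm

models = list_supported_models("OpenAI")   # ["gpt-4.1-nano", "gpt-4o-mini"]
llm = get_llm("OpenAI", models[1])         # falls back to OPENAI_API_KEY when api_key is None
reply = llm.invoke("Reply with a single word: ready")
print(reply.content)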
backend/llm_utils/label_generator.py
ADDED
@@ -0,0 +1,72 @@
+from hashlib import sha256
+import json
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from typing import Optional
+import os
+
+#get_top_words_at_time
+from backend.inference.process_beta import get_top_words_at_time
+
+def label_topic_temporal(word_trajectory_str: str, llm, cache_path: Optional[str] = None) -> str:
+    """
+    Label a dynamic topic by providing the LLM with the top words over time.
+
+    Args:
+        word_trajectory_str (str): Formatted keyword evolution string.
+        llm: LangChain-compatible LLM instance.
+        cache_path (Optional[str]): Path to the cache file (JSON).
+
+    Returns:
+        str: Short label for the topic.
+    """
+    topic_key = sha256(word_trajectory_str.encode()).hexdigest()
+
+    # Load cache
+    if cache_path is not None and os.path.exists(cache_path):
+        with open(cache_path, "r") as f:
+            label_cache = json.load(f)
+    else:
+        label_cache = {}
+
+    # Return cached result
+    if topic_key in label_cache:
+        return label_cache[topic_key]
+
+    # Prompt template
+    prompt = ChatPromptTemplate.from_template(
+        "You are an expert in topic modeling and temporal data analysis. "
+        "Given the top words for a topic across multiple time points, your task is to return a short, specific, descriptive topic label. "
+        "Avoid vague, generic, or overly broad labels. Focus on consistent themes in the top words over time. "
+        "Use concise noun phrases, 2–5 words max. Do NOT include any explanation, justification, or extra output.\n\n"
+        "Top words over time:\n{trajectory}\n\n"
+        "Return ONLY the label (no quotes, no extra text):"
+    )
+    chain = prompt | llm | StrOutputParser()
+
+    try:
+        label = chain.invoke({"trajectory": word_trajectory_str}).strip()
+    except Exception as e:
+        label = "Unknown Topic"
+        print(f"[Labeling Error] {e}")
+
+    # Update cache and save
+    label_cache[topic_key] = label
+    if cache_path is not None:
+        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
+        with open(cache_path, "w") as f:
+            json.dump(label_cache, f, indent=2)
+
+    return label
+
+
+def get_topic_labels(beta, vocab, time_labels, llm, cache_path):
+    topic_labels = {}
+    for topic_id in range(beta.shape[1]):
+        word_trajectory_str = "\n".join([
+            f"{time_labels[t]}: {', '.join(get_top_words_at_time(beta, vocab, topic_id, t, top_n=10))}"
+            for t in range(beta.shape[0])
+        ])
+        label = label_topic_temporal(word_trajectory_str, llm=llm, cache_path=cache_path)
+        topic_labels[topic_id] = label
+    return topic_labels
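label_topic_temporal expects the per-time top words already formatted as a single string, one time point per line, which is exactly how get_topic_labels builds its input. A hedged sketch, not part of the commit, assuming an llm object obtained from get_llm and no on-disk cache:

from backend.llm_utils.label_generator import label_topic_temporal

trajectory = (
    "2019: topic, model, inference, corpus\n"
    "2020: topic, model, neural, embedding\n"
    "2021: topic, model, transformer, embedding"
)
# cache_path=None skips reading and writing the JSON label cache
label = label_topic_temporal(trajectory, llm=llm, cache_path=None)
print(label)   # e.g. a short noun phrase such as "Neural topic modeling"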
backend/llm_utils/summarizer.py
ADDED
@@ -0,0 +1,192 @@
+import hashlib
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+from sentence_transformers import SentenceTransformer
+import faiss
+
+from langchain.prompts import ChatPromptTemplate
+from langchain.docstore.document import Document
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationChain
+
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+# --- MMR Utilities ---
+def build_mmr_index(docs):
+    texts = [doc['text'] for doc in docs if 'text' in doc]
+    documents = [Document(page_content=text) for text in texts]
+
+    model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = model.encode([doc.page_content for doc in documents], convert_to_numpy=True)
+    faiss.normalize_L2(embeddings)
+
+    index = faiss.IndexFlatIP(embeddings.shape[1])
+    index.add(embeddings)
+
+    return model, index, embeddings, documents
+
+def get_mmr_sample(model, index, embeddings, documents, query, k=15, lambda_mult=0.7):
+    if len(documents) == 0:
+        print("Warning: No documents available, returning empty list.")
+        return []
+
+    if len(documents) <= k:
+        print(f"Warning: Only {len(documents)} documents available, returning all.")
+        return documents
+
+    else:
+        query_vec = model.encode(query, convert_to_numpy=True)
+        query_vec = query_vec / np.linalg.norm(query_vec)
+
+        # Get candidate indices from FAISS (k * 4 or less if not enough documents)
+        num_candidates = min(k * 4, len(documents))
+        D, I = index.search(np.expand_dims(query_vec, axis=0), num_candidates)
+        candidate_idxs = list(I[0])
+
+        selected = []
+        while len(selected) < k and candidate_idxs:
+            if not selected:
+                selected.append(candidate_idxs.pop(0))
+                continue
+
+            mmr_scores = []
+            for idx in candidate_idxs:
+                relevance = cosine_similarity([query_vec], [embeddings[idx]])[0][0]
+                diversity = max([
+                    cosine_similarity([embeddings[idx]], [embeddings[sel]])[0][0]
+                    for sel in selected
+                ])
+                mmr_score = lambda_mult * relevance - (1 - lambda_mult) * diversity
+                mmr_scores.append((idx, mmr_score))
+
+            next_best = max(mmr_scores, key=lambda x: x[1])[0]
+            selected.append(next_best)
+            candidate_idxs.remove(next_best)
+
+        return [documents[i] for i in selected]
+
+
+# --- Summarization ---
+def summarize_docs(word, timestamp, docs, llm, k):
+    if not docs:
+        return "No documents available for this word at this time.", [], 0
+
+    try:
+        model, index, embeddings, documents = build_mmr_index(docs)
+        mmr_docs = get_mmr_sample(model, index, embeddings, documents, query=word, k=k)
+
+        context_texts = "\n".join(f"- {doc.page_content}" for doc in mmr_docs)
+
+        prompt_template = ChatPromptTemplate.from_template(
+            "Given the following documents from {timestamp} containing the word '{word}', "
+            "identify the key themes or distinct discussion points that were prevalent during that time. "
+            "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
+            "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
+            "Return no more than 5–7 bullets.\n\n"
+            "{context_texts}\n\nSummary:"
+        )
+
+        chain = prompt_template | llm
+        summary = chain.invoke({
+            "word": word,
+            "timestamp": timestamp,
+            "context_texts": context_texts
+        }).content.strip()
+
+        return summary, mmr_docs
+
+    except Exception as e:
+        return f"[Error summarizing: {e}]", [], 0
+
+
+def summarize_multiword_docs(words, timestamp, docs, llm, k):
+    if not docs:
+        return "No common documents available for these words at this time.", []
+
+    try:
+        model, index, embeddings, documents = build_mmr_index(docs)
+        query = " ".join(words)
+        mmr_docs = get_mmr_sample(model, index, embeddings, documents, query=query, k=k)
+
+        context_texts = "\n".join(f"- {doc.page_content}" for doc in mmr_docs)
+
+        prompt_template = ChatPromptTemplate.from_template(
+            "Given the following documents from {timestamp} that all mention the words: '{word_list}', "
+            "identify the key themes or distinct discussion points that were prevalent during that time. "
+            "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
+            "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
+            "Return no more than 5–7 bullets.\n\n"
+            "{context_texts}\n\n"
+            "Concise Thematic Summary:"
+        )
+
+        chain = prompt_template | llm
+        summary = chain.invoke({
+            "word_list": ", ".join(words),
+            "timestamp": timestamp,
+            "context_texts": context_texts
+        }).content.strip()
+
+        return summary, mmr_docs
+
+    except Exception as e:
+        return f"[Error summarizing: {e}]", []
+
+
+# --- Follow-up Question Handler (Improved) ---
+def ask_multiturn_followup(history: list, question: str, llm, context_texts: str) -> str:
+    """
+    Handles multi-turn follow-up questions based on a provided set of documents.
+
+    This function now REQUIRES context_texts to be provided, ensuring the LLM
+    is always grounded in the source documents for follow-up questions.
+
+    Args:
+        history (list): A list of dictionaries representing the conversation history
+            (e.g., [{"role": "user", "content": "..."}]).
+        question (str): The user's new follow-up question.
+        llm: The initialized language model instance.
+        context_texts (str): A single string containing all the numbered documents
+            for context.
+
+    Returns:
+        str: The AI's response to the follow-up question.
+    """
+    try:
+        # 1. Reconstruct conversation memory from the history provided from the UI
+        memory = ConversationBufferMemory(return_messages=True)
+        for turn in history:
+            if turn["role"] == "user":
+                memory.chat_memory.add_user_message(turn["content"])
+            elif turn["role"] == "assistant":
+                memory.chat_memory.add_ai_message(turn["content"])
+
+        # 2. Define the system instruction that grounds the LLM
+        system_instruction = (
+            "You are an assistant answering questions strictly based on the provided sample documents below. "
+            "Your memory contains the previous turns of this conversation. "
+            "If the answer is not clearly available in the text, respond with: "
+            "'The information is not available in the documents provided.'\n\n"
+        )
+
+        # 3. Create the full prompt. No more conditional logic, as context is required.
+        # The `ConversationChain` will automatically use the memory, so we only need
+        # to provide the current input, which includes the grounding documents.
+        full_prompt = (
+            f"{system_instruction}"
+            f"--- DOCUMENTS ---\n{context_texts.strip()}\n\n"
+            f"--- QUESTION ---\n{question}"
+        )
+
+        # 4. Create and run the conversation chain
+        conversation = ConversationChain(llm=llm, memory=memory, verbose=False)
+        response = conversation.predict(input=full_prompt)
+
+        return response.strip()
+
+    except Exception as e:
+        # Good practice to log the full exception for easier debugging
+        print(f"[ERROR] in ask_multiturn_followup: {e}")
+        return f"[Error during multi-turn follow-up. Please check the logs.]"
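summarize_docs expects docs as a list of dicts with a "text" field; it embeds them with all-MiniLM-L6-v2, keeps an MMR sample of at most k documents, and asks the chat model for a bulleted summary. A hedged sketch (not part of the commit; it needs the sentence-transformers model download and a chat LLM with API access, and the document texts here are made up):

from backend.llm_utils.summarizer import summarize_docs

docs = [
    {"text": "Transformer-based encoders dominate the 2020 proceedings."},
    {"text": "Pretrained language models are adapted with task-specific fine-tuning."},
    {"text": "Attention visualization is used to interpret transformer models."},
]
summary, picked_docs = summarize_docs("transformer", "2020", docs, llm=llm, k=2)
print(summary)
print(len(picked_docs))   # number of MMR-selected documents actually sent to the LLM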
backend/llm_utils/token_utils.py
ADDED
@@ -0,0 +1,167 @@
+from typing import Literal
+import tiktoken
+import anthropic
+from typing import List
+
+# Gemini requires the Vertex AI SDK
+try:
+    from vertexai.preview import tokenization as vertex_tokenization
+except ImportError:
+    vertex_tokenization = None
+
+# Mistral requires the SentencePiece tokenizer
+try:
+    import sentencepiece as spm
+except ImportError:
+    spm = None
+
+# ---------------------------
+# Individual Token Counters
+# ---------------------------
+
+def count_tokens_openai(text: str, model_name: str) -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model_name)
+    except KeyError:
+        encoding = tiktoken.get_encoding("cl100k_base")  # fallback
+    return len(encoding.encode(text))
+
+def count_tokens_anthropic(text: str, model_name: str) -> int:
+    try:
+        client = anthropic.Anthropic()
+        response = client.messages.count_tokens(
+            model=model_name,
+            messages=[{"role": "user", "content": text}]
+        )
+        return response['input_tokens']
+    except Exception as e:
+        raise RuntimeError(f"Anthropic token counting failed: {e}")
+
+def count_tokens_gemini(text: str, model_name: str) -> int:
+    if vertex_tokenization is None:
+        raise ImportError("Please install vertexai: pip install google-cloud-aiplatform[tokenization]")
+    try:
+        tokenizer = vertex_tokenization.get_tokenizer_for_model("gemini-1.5-flash-002")
+        result = tokenizer.count_tokens(text)
+        return result.total_tokens
+    except Exception as e:
+        raise RuntimeError(f"Gemini token counting failed: {e}")
+
+def count_tokens_mistral(text: str) -> int:
+    if spm is None:
+        raise ImportError("Please install sentencepiece: pip install sentencepiece")
+    try:
+        sp = spm.SentencePieceProcessor()
+        # IMPORTANT: You must provide the correct path to the tokenizer model file
+        sp.load("mistral_tokenizer.model")
+        tokens = sp.encode(text, out_type=str)
+        return len(tokens)
+    except Exception as e:
+        raise RuntimeError(f"Mistral token counting failed: {e}")
+
+# ---------------------------
+# Unified Token Counter
+# ---------------------------
+
+def count_tokens(text: str, model_name: str, provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]) -> int:
+    if provider == "OpenAI":
+        return count_tokens_openai(text, model_name)
+    elif provider == "Anthropic":
+        return count_tokens_anthropic(text, model_name)
+    elif provider == "Gemini":
+        return count_tokens_gemini(text, model_name)
+    elif provider == "Mistral":
+        return count_tokens_mistral(text)
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
+
+
+def get_token_limit_for_model(model_name, provider):
+    # Example values; update as needed for your providers
+    if provider == "openai":
+        if "gpt-4.1-nano" in model_name:
+            return 1047576  # Based on search results
+        elif "gpt-4o-mini" in model_name:
+            return 128000  # Based on search results
+    elif provider == "anthropic":
+        if "claude-3-opus" in model_name:
+            return 200000  # Based on search results
+        elif "claude-3-sonnet" in model_name:
+            return 200000  # Based on search results
+    elif provider == "gemini":
+        if "gemini-2.0-flash-lite" in model_name:
+            return 1048576  # Based on search results
+        elif "gemini-1.5-flash" in model_name:
+            return 1048576  # Based on search results
+    elif provider == "mistral":
+        if "mistral-small" in model_name:
+            return 32000  # Based on search results
+        elif "mistral-medium" in model_name:
+            return 32000  # Based on search results
+    return 8000  # default fallback
+
+
+def estimate_avg_tokens_per_doc(
+    docs: List[str],
+    model_name: str,
+    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]
+) -> float:
+    """
+    Estimate the average number of tokens per document for the given model.
+
+    Args:
+        docs (List[str]): List of documents.
+        model_name (str): Model name.
+        provider (Literal): LLM provider.
+
+    Returns:
+        float: Average number of tokens per document.
+    """
+    if not docs:
+        return 0.0
+    token_counts = [count_tokens(doc, model_name, provider) for doc in docs]
+    return sum(token_counts) / len(token_counts)
+
+def estimate_max_k(
+    docs: List[str],
+    model_name: str,
+    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
+    margin_ratio: float = 0.1,
+) -> int:
+    """
+    Estimate the maximum number of documents that can fit in the context window.
+
+    Returns:
+        int: Estimated K.
+    """
+    if not docs:
+        return 0
+
+    max_tokens = get_token_limit_for_model(model_name, provider)
+    margin = int(max_tokens * margin_ratio)
+    available_tokens = max_tokens - margin
+
+    avg_tokens_per_doc = estimate_avg_tokens_per_doc(docs, model_name, provider)
+    if avg_tokens_per_doc == 0:
+        return 0
+
+    return min(len(docs), int(available_tokens // avg_tokens_per_doc))
+
+def estimate_max_k_fast(docs, margin_ratio=0.1, max_tokens=8000, model_name="gpt-3.5-turbo"):
+    enc = tiktoken.encoding_for_model(model_name)
+    avg_len = sum(len(enc.encode(doc)) for doc in docs[:20]) / min(20, len(docs))
+    margin = int(max_tokens * margin_ratio)
+    available = max_tokens - margin
+    return min(len(docs), int(available // avg_len))
+
+def estimate_k_max_from_word_stats(
+    avg_words_per_doc: float,
+    margin_ratio: float = 0.1,
+    avg_tokens_per_word: float = 1.3,
+    model_name=None,
+    provider=None
+) -> int:
+    model_token_limit = get_token_limit_for_model(model_name, provider)
+    effective_limit = int(model_token_limit * (1 - margin_ratio))
+    est_tokens_per_doc = avg_words_per_doc * avg_tokens_per_word
+    return int(effective_limit // est_tokens_per_doc)
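estimate_k_max_from_word_stats needs only corpus statistics, so its budget arithmetic can be checked by hand; note that get_token_limit_for_model matches the lowercase provider spelling. A small worked example (illustrative, not part of the commit):

from backend.llm_utils.token_utils import estimate_k_max_from_word_stats

# gpt-4o-mini limit 128000 tokens; keep a 10% margin -> 115200 usable tokens.
# 150 words/doc * 1.3 tokens/word = 195 tokens/doc, so 115200 // 195 = 590 documents.
k_max = estimate_k_max_from_word_stats(
    avg_words_per_doc=150,
    margin_ratio=0.1,
    avg_tokens_per_word=1.3,
    model_name="gpt-4o-mini",
    provider="openai",
)
print(k_max)   # 590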
backend/models/CFDTM/CFDTM.py
ADDED
@@ -0,0 +1,127 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ETC import ETC
+from .UWE import UWE
+from .Encoder import MLPEncoder
+
+
+class CFDTM(nn.Module):
+    '''
+    Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings
+
+    Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.
+    '''
+
+    def __init__(self,
+                 vocab_size,
+                 train_time_wordfreq,
+                 num_times,
+                 pretrained_WE=None,
+                 num_topics=50,
+                 en_units=100,
+                 temperature=0.1,
+                 beta_temp=1.0,
+                 weight_neg=1.0e+7,
+                 weight_pos=1.0e+1,
+                 weight_UWE=1.0e+3,
+                 neg_topk=15,
+                 dropout=0.,
+                 embed_size=200
+                ):
+        super().__init__()
+
+        self.num_topics = num_topics
+        self.beta_temp = beta_temp
+        self.train_time_wordfreq = train_time_wordfreq
+        self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)
+
+        self.a = 1 * np.ones((1, num_topics)).astype(np.float32)
+        self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T))
+        self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T))
+
+        self.mu2.requires_grad = False
+        self.var2.requires_grad = False
+
+        self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)
+
+        if pretrained_WE is None:
+            self.word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
+            self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
+
+        else:
+            self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())
+
+        # topic_embeddings: TxKxD
+        self.topic_embeddings = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1])).repeat(num_times, 1, 1)
+        self.topic_embeddings = nn.Parameter(self.topic_embeddings)
+
+        self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
+        self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)
+
+    def get_beta(self):
+        dist = self.pairwise_euclidean_dist(F.normalize(self.topic_embeddings, dim=-1), F.normalize(self.word_embeddings, dim=-1))
+        beta = F.softmax(-dist / self.beta_temp, dim=1)
+
+        return beta
+
+    def pairwise_euclidean_dist(self, x, y):
+        cost = torch.sum(x ** 2, axis=-1, keepdim=True) + torch.sum(y ** 2, axis=-1) - 2 * torch.matmul(x, y.t())
+        return cost
+
+    def get_theta(self, x, times=None):
+        theta, mu, logvar = self.encoder(x)
+        if self.training:
+            return theta, mu, logvar
+
+        return theta
+
+    def get_KL(self, mu, logvar):
+        var = logvar.exp()
+        var_division = var / self.var2
+        diff = mu - self.mu2
+        diff_term = diff * diff / self.var2
+        logvar_division = self.var2.log() - logvar
+        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(axis=1) - self.num_topics)
+
+        return KLD.mean()
+
+    def get_NLL(self, theta, beta, x, recon_x=None):
+        if recon_x is None:
+            recon_x = self.decode(theta, beta)
+        recon_loss = -(x * recon_x.log()).sum(axis=1)
+
+        return recon_loss
+
+    def decode(self, theta, beta):
+        d1 = F.softmax(self.decoder_bn(torch.bmm(theta.unsqueeze(1), beta).squeeze(1)), dim=-1)
+        return d1
+
+    def forward(self, x, times):
+        loss = 0.
+
+        theta, mu, logvar = self.get_theta(x)
+        kl_theta = self.get_KL(mu, logvar)
+
+        loss += kl_theta
+
+        beta = self.get_beta()
+        time_index_beta = beta[times]
+        recon_x = self.decode(theta, time_index_beta)
+        NLL = self.get_NLL(theta, time_index_beta, x, recon_x)
+        NLL = NLL.mean()
+        loss += NLL
+
+        loss_ETC = self.ETC(self.topic_embeddings)
+        loss += loss_ETC
+
+        loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)
+        loss += loss_UWE
+
+        rst_dict = {
+            'loss': loss,
+        }
+
+        return rst_dict
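Because CFDTM only needs bag-of-words tensors and per-slice word frequencies, a shape-level smoke test is easy to set up. The sketch below uses random tensors purely to illustrate the expected inputs and the T x K x V beta output; the sizes and data are made up and the snippet is not part of this commit:

import torch
from backend.models.CFDTM.CFDTM import CFDTM

V, T, K = 500, 4, 10
x = torch.rand(8, V)                # batch of bag-of-words vectors
times = torch.randint(0, T, (8,))   # time-slice index of each document
time_wordfreq = torch.rand(T, V)    # per-slice word frequencies consumed by UWE

model = CFDTM(vocab_size=V, train_time_wordfreq=time_wordfreq, num_times=T, num_topics=K)
out = model(x, times)
print(out["loss"])                  # scalar training loss (KL + NLL + ETC + UWE terms)
print(model.get_beta().shape)       # torch.Size([4, 10, 500]) -> T x K x V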
backend/models/CFDTM/ETC.py
ADDED
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ETC(nn.Module):
+    def __init__(self, num_times, temperature, weight_neg, weight_pos):
+        super().__init__()
+        self.num_times = num_times
+        self.weight_neg = weight_neg
+        self.weight_pos = weight_pos
+        self.temperature = temperature
+
+    def forward(self, topic_embeddings):
+        loss = 0.
+        loss_neg = 0.
+        loss_pos = 0.
+
+        for t in range(self.num_times):
+            loss_neg += self.compute_loss(topic_embeddings[t], topic_embeddings[t], self.temperature, self_contrast=True)
+
+        for t in range(1, self.num_times):
+            loss_pos += self.compute_loss(topic_embeddings[t], topic_embeddings[t - 1].detach(), self.temperature, self_contrast=False, only_pos=True)
+
+        loss_neg *= (self.weight_neg / self.num_times)
+        loss_pos *= (self.weight_pos / (self.num_times - 1))
+        loss = loss_neg + loss_pos
+
+        return loss
+
+    def compute_loss(self, anchor_feature, contrast_feature, temperature, self_contrast=False, only_pos=False, all_neg=False):
+        # KxK
+        anchor_dot_contrast = torch.div(
+            torch.matmul(F.normalize(anchor_feature, dim=1), F.normalize(contrast_feature, dim=1).T),
+            temperature
+        )
+
+        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+        logits = anchor_dot_contrast - logits_max.detach()
+
+        pos_mask = torch.eye(anchor_dot_contrast.shape[0]).to(anchor_dot_contrast.device)
+
+        if self_contrast is False:
+            if only_pos is False:
+                if all_neg is True:
+                    exp_logits = torch.exp(logits)
+                    sum_exp_logits = exp_logits.sum(1)
+                    log_prob = -torch.log(sum_exp_logits + 1e-12)
+
+                    mean_log_prob = -log_prob.sum() / (logits.shape[0] * logits.shape[1])
+            else:
+                # only pos
+                mean_log_prob = -(logits * pos_mask).sum() / pos_mask.sum()
+        else:
+            # self contrast: push away from each other in the same time slice.
+            exp_logits = torch.exp(logits) * (1 - pos_mask)
+            sum_exp_logits = exp_logits.sum(1)
+            log_prob = -torch.log(sum_exp_logits + 1e-12)
+
+            mean_log_prob = -log_prob.sum() / (1 - pos_mask).sum()
+
+        return mean_log_prob
backend/models/CFDTM/Encoder.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLPEncoder(nn.Module):
+    def __init__(self, vocab_size, num_topic, hidden_dim, dropout):
+        super().__init__()
+
+        self.fc11 = nn.Linear(vocab_size, hidden_dim)
+        self.fc12 = nn.Linear(hidden_dim, hidden_dim)
+        self.fc21 = nn.Linear(hidden_dim, num_topic)
+        self.fc22 = nn.Linear(hidden_dim, num_topic)
+
+        self.fc1_drop = nn.Dropout(dropout)
+        self.z_drop = nn.Dropout(dropout)
+
+        self.mean_bn = nn.BatchNorm1d(num_topic, affine=True)
+        self.mean_bn.weight.requires_grad = False
+        self.logvar_bn = nn.BatchNorm1d(num_topic, affine=True)
+        self.logvar_bn.weight.requires_grad = False
+
+    def reparameterize(self, mu, logvar):
+        if self.training:
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            return mu + (eps * std)
+        else:
+            return mu
+
+    def forward(self, x):
+        e1 = F.softplus(self.fc11(x))
+        e1 = F.softplus(self.fc12(e1))
+        e1 = self.fc1_drop(e1)
+        mu = self.mean_bn(self.fc21(e1))
+        logvar = self.logvar_bn(self.fc22(e1))
+        theta = self.reparameterize(mu, logvar)
+        theta = F.softmax(theta, dim=1)
+        theta = self.z_drop(theta)
+        return theta, mu, logvar
backend/models/CFDTM/UWE.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+
+
+class UWE(nn.Module):
+    def __init__(self, ETC, num_times, temperature, weight_UWE, neg_topk):
+        super().__init__()
+
+        self.ETC = ETC
+        self.weight_UWE = weight_UWE
+        self.num_times = num_times
+        self.temperature = temperature
+        self.neg_topk = neg_topk
+
+    def forward(self, time_wordcount, beta, topic_embeddings, word_embeddings):
+        assert(self.num_times == time_wordcount.shape[0])
+
+        topk_indices = self.get_topk_indices(beta)
+
+        loss_UWE = 0.
+        cnt_valid_times = 0.
+        for t in range(self.num_times):
+            neg_idx = torch.where(time_wordcount[t] == 0)[0]
+
+            time_topk_indices = topk_indices[t]
+            neg_idx = list(set(neg_idx.cpu().tolist()).intersection(set(time_topk_indices.cpu().tolist())))
+            neg_idx = torch.tensor(neg_idx).long().to(time_wordcount.device)
+
+            if len(neg_idx) == 0:
+                continue
+
+            time_neg_WE = word_embeddings[neg_idx]
+
+            # topic_embeddings[t]: K x D
+            # word_embeddings[neg_idx]: |V_{neg}| x D
+            loss_UWE += self.ETC.compute_loss(topic_embeddings[t], time_neg_WE, temperature=self.temperature, all_neg=True)
+            cnt_valid_times += 1
+
+        if cnt_valid_times > 0:
+            loss_UWE *= (self.weight_UWE / cnt_valid_times)
+
+        return loss_UWE
+
+    def get_topk_indices(self, beta):
+        # topk_indices: T x K x neg_topk
+        topk_indices = torch.topk(beta, k=self.neg_topk, dim=-1).indices
+        topk_indices = torch.flatten(topk_indices, start_dim=1)
+        return topk_indices
backend/models/CFDTM/__init__.py
ADDED
File without changes
backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc
ADDED
Binary file (4.01 kB)
backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc
ADDED
Binary file (1.85 kB)
backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc
ADDED
Binary file (1.52 kB)
backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc
ADDED
Binary file (1.56 kB)
backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (158 Bytes)
backend/models/DBERTopic_trainer.py
ADDED
@@ -0,0 +1,99 @@
+import numpy as np
+from bertopic import BERTopic
+from backend.datasets.utils import _utils
+from backend.datasets.utils.logger import Logger
+
+logger = Logger("WARNING")
+
+
+class DBERTopicTrainer:
+    def __init__(self,
+                 dataset,
+                 num_topics=20,
+                 num_top_words=15,
+                 nr_bins=20,
+                 global_tuning=True,
+                 evolution_tuning=True,
+                 datetime_format=None,
+                 verbose=False):
+
+        self.dataset = dataset
+        self.docs = dataset.raw_documents
+        self.num_topics=num_topics
+        # self.timestamps = dataset.train_times
+        self.vocab = dataset.vocab
+        self.num_top_words = num_top_words
+        # self.nr_bins = nr_bins
+        # self.global_tuning = global_tuning
+        # self.evolution_tuning = evolution_tuning
+        # self.datetime_format = datetime_format
+        self.verbose = verbose
+
+        if verbose:
+            logger.set_level("DEBUG")
+        else:
+            logger.set_level("WARNING")
+
+    def train(self, timestamps, datetime_format='%Y'):
+        logger.info("Fitting BERTopic...")
+        self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
+        self.topics, _ = self.model.fit_transform(self.docs)
+
+        logger.info("Running topics_over_time...")
+        self.topics_over_time_df = self.model.topics_over_time(
+            docs=self.docs,
+            timestamps=timestamps,
+            nr_bins=len(set(timestamps)),
+            datetime_format=datetime_format
+        )
+
+        self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
+        self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
+        self.vocab = self.model.vectorizer_model.get_feature_names_out()
+        self.V = len(self.vocab)
+        self.K = len(self.unique_topics)
+        self.T = len(self.unique_timestamps)
+
+    def get_beta(self):
+        logger.info("Generating β matrix...")
+
+        beta = np.zeros((self.T, self.K, self.V))
+        topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}
+        timestamp_to_index = {timestamp: idx for idx, timestamp in enumerate(self.unique_timestamps)}
+
+        # Extract topic representations at each time
+        for t_idx, timestamp in enumerate(self.unique_timestamps):
+            selection = self.topics_over_time_df[self.topics_over_time_df["Timestamp"] == timestamp]
+            for _, row in selection.iterrows():
+                topic = row["Topic"]
+                words = row["Words"].split(", ")
+                if topic not in topic_to_index:
+                    continue
+                k = topic_to_index[topic]
+                for word in words:
+                    if word in self.vocab:
+                        v = np.where(self.vocab == word)[0][0]
+                        beta[t_idx, k, v] += 1.0
+
+        # Normalize each β_tk to be a probability distribution
+        beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
+        return beta
+
+    def get_top_words(self, num_top_words=None):
+        if num_top_words is None:
+            num_top_words = self.num_top_words
+        beta = self.get_beta()
+        top_words_list = list()
+        for time in range(beta.shape[0]):
+            top_words = _utils.get_top_words(beta[time], self.vocab, num_top_words, self.verbose)
+            top_words_list.append(top_words)
+        return top_words_list
+
+    def get_theta(self):
+        # Not applicable for BERTopic; can return topic assignments or soft topic distributions if required
+        logger.warning("get_theta is not implemented for BERTopic.")
+        return None
+
+    def export_theta(self):
+        logger.warning("export_theta is not implemented for BERTopic.")
+        return None, None
backend/models/DETM.py
ADDED
@@ -0,0 +1,259 @@
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class DETM(nn.Module):
+    """
+    The Dynamic Embedded Topic Model. 2019
+
+    Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei
+    """
+    def __init__(self, vocab_size, num_times, train_size, train_time_wordfreq,
+                 num_topics=50, train_WE=True, pretrained_WE=None, en_units=800,
+                 eta_hidden_size=200, rho_size=300, enc_drop=0.0, eta_nlayers=3,
+                 eta_dropout=0.0, delta=0.005, theta_act='relu', device='cpu'):
+        super().__init__()
+
+        ## define hyperparameters
+        self.num_topics = num_topics
+        self.num_times = num_times
+        self.vocab_size = vocab_size
+        self.eta_hidden_size = eta_hidden_size
+        self.rho_size = rho_size
+        self.enc_drop = enc_drop
+        self.eta_nlayers = eta_nlayers
+        self.t_drop = nn.Dropout(enc_drop)
+        self.eta_dropout = eta_dropout
+        self.delta = delta
+        self.train_WE = train_WE
+        self.train_size = train_size
+        self.rnn_inp = train_time_wordfreq
+        self.device = device
+
+        self.theta_act = self.get_activation(theta_act)
+
+        ## define the word embedding matrix \rho
+        if self.train_WE:
+            self.rho = nn.Linear(self.rho_size, self.vocab_size, bias=False)
+        else:
+            rho = nn.Embedding(pretrained_WE.size())
+            rho.weight.data = torch.from_numpy(pretrained_WE)
+            self.rho = rho.weight.data.clone().float().to(self.device)
+
+        ## define the variational parameters for the topic embeddings over time (alpha) ... alpha is K x T x L
+        self.mu_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
+        self.logsigma_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
+
+        ## define variational distribution for \theta_{1:D} via amortizartion... theta is K x D
+        self.q_theta = nn.Sequential(
+            nn.Linear(self.vocab_size + self.num_topics, en_units),
+            self.theta_act,
+            nn.Linear(en_units, en_units),
+            self.theta_act,
+        )
+        self.mu_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
+        self.logsigma_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
+
+        ## define variational distribution for \eta via amortizartion... eta is K x T
+        self.q_eta_map = nn.Linear(self.vocab_size, self.eta_hidden_size)
+        self.q_eta = nn.LSTM(self.eta_hidden_size, self.eta_hidden_size, self.eta_nlayers, dropout=self.eta_dropout)
+        self.mu_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
+        self.logsigma_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
+
+        self.decoder_bn = nn.BatchNorm1d(vocab_size)
+        self.decoder_bn.weight.requires_grad = False
+
+    def get_activation(self, act):
+        activations = {
+            'tanh': nn.Tanh(),
+            'relu': nn.ReLU(),
+            'softplus': nn.Softplus(),
+            'rrelu': nn.RReLU(),
+            'leakyrelu': nn.LeakyReLU(),
+            'elu': nn.ELU(),
+            'selu': nn.SELU(),
+            'glu': nn.GLU(),
+        }
+
+        if act in activations:
+            act = activations[act]
+        else:
+            print('Defaulting to tanh activations...')
+            act = nn.Tanh()
+        return act
+
+    def reparameterize(self, mu, logvar):
+        """Returns a sample from a Gaussian distribution via reparameterization.
+        """
+        if self.training:
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            return eps.mul_(std).add_(mu)
+        else:
+            return mu
+
+    def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
+        """Returns KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) ).
+        """
+        if p_mu is not None and p_logsigma is not None:
+            sigma_q_sq = torch.exp(q_logsigma)
+            sigma_p_sq = torch.exp(p_logsigma)
+            kl = ( sigma_q_sq + (q_mu - p_mu)**2 ) / ( sigma_p_sq + 1e-6 )
+            kl = kl - 1 + p_logsigma - q_logsigma
+            kl = 0.5 * torch.sum(kl, dim=-1)
+        else:
+            kl = -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)
+        return kl
+
+    def get_alpha(self): ## mean field
+        alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(self.device)
+        kl_alpha = []
+
+        alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])
+
+        # TODO: why logsigma_p_0 is zero?
+        p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
+        logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
+        kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
+        kl_alpha.append(kl_0)
+        for t in range(1, self.num_times):
+            alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])
+
+            p_mu_t = alphas[t - 1]
+            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(self.device))
+            kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
+            kl_alpha.append(kl_t)
+        kl_alpha = torch.stack(kl_alpha).sum()
+        return alphas, kl_alpha.sum()
+
+    def get_eta(self, rnn_inp): ## structured amortized inference
+        inp = self.q_eta_map(rnn_inp).unsqueeze(1)
+        hidden = self.init_hidden()
+        output, _ = self.q_eta(inp, hidden)
+        output = output.squeeze()
+
+        etas = torch.zeros(self.num_times, self.num_topics).to(self.device)
+        kl_eta = []
+
+        inp_0 = torch.cat([output[0], torch.zeros(self.num_topics,).to(self.device)], dim=0)
+        mu_0 = self.mu_q_eta(inp_0)
+        logsigma_0 = self.logsigma_q_eta(inp_0)
+        etas[0] = self.reparameterize(mu_0, logsigma_0)
+
+        p_mu_0 = torch.zeros(self.num_topics,).to(self.device)
+        logsigma_p_0 = torch.zeros(self.num_topics,).to(self.device)
+        kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
+        kl_eta.append(kl_0)
+
+        for t in range(1, self.num_times):
+            inp_t = torch.cat([output[t], etas[t-1]], dim=0)
+            mu_t = self.mu_q_eta(inp_t)
+            logsigma_t = self.logsigma_q_eta(inp_t)
+            etas[t] = self.reparameterize(mu_t, logsigma_t)
+
+            p_mu_t = etas[t-1]
+            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics,).to(self.device))
+            kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
+            kl_eta.append(kl_t)
+        kl_eta = torch.stack(kl_eta).sum()
+
+        return etas, kl_eta
+
+    def get_theta(self, bows, times, eta=None): ## amortized inference
+        """Returns the topic proportions.
+        """
+
+        normalized_bows = bows / bows.sum(1, keepdims=True)
+
+        if eta is None and self.training is False:
+            eta, kl_eta = self.get_eta(self.rnn_inp)
+
+        eta_td = eta[times]
+        inp = torch.cat([normalized_bows, eta_td], dim=1)
+        q_theta = self.q_theta(inp)
+        if self.enc_drop > 0:
+            q_theta = self.t_drop(q_theta)
+        mu_theta = self.mu_q_theta(q_theta)
+        logsigma_theta = self.logsigma_q_theta(q_theta)
+        z = self.reparameterize(mu_theta, logsigma_theta)
+        theta = F.softmax(z, dim=-1)
+        kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(self.device))
+
+        if self.training:
+            return theta, kl_theta
+        else:
+            return theta
+
+    @property
+    def word_embeddings(self):
+        return self.rho.weight
+
+    @property
+    def topic_embeddings(self):
+        alpha, _ = self.get_alpha()
+        return alpha
+
+    def get_beta(self, alpha=None):
+        """Returns the topic matrix \beta of shape T x K x V
+        """
+
+        if alpha is None and self.training is False:
+            alpha, kl_alpha = self.get_alpha()
+
+        if self.train_WE:
+            logit = self.rho(alpha.view(alpha.size(0) * alpha.size(1), self.rho_size))
+        else:
+            tmp = alpha.view(alpha.size(0) * alpha.size(1), self.rho_size)
+            logit = torch.mm(tmp, self.rho.permute(1, 0))
+        logit = logit.view(alpha.size(0), alpha.size(1), -1)
+
+        beta = F.softmax(logit, dim=-1)
+
+        return beta
+
+    def get_NLL(self, theta, beta, bows):
+        theta = theta.unsqueeze(1)
+        loglik = torch.bmm(theta, beta).squeeze(1)
+        loglik = torch.log(loglik + 1e-12)
+        nll = -loglik * bows
+        nll = nll.sum(-1)
+        return nll
+
+    def forward(self, bows, times):
+        bsz = bows.size(0)
+        coeff = self.train_size / bsz
+        eta, kl_eta = self.get_eta(self.rnn_inp)
+        theta, kl_theta = self.get_theta(bows, times, eta)
+        kl_theta = kl_theta.sum() * coeff
+
+        alpha, kl_alpha = self.get_alpha()
+        beta = self.get_beta(alpha)
+
+        beta = beta[times]
+        # beta = beta[times.type('torch.LongTensor')]
+        nll = self.get_NLL(theta, beta, bows)
+        nll = nll.sum() * coeff
+
+        loss = nll + kl_eta + kl_theta
+
+        rst_dict = {
+            'loss': loss,
+            'nll': nll,
+            'kl_eta': kl_eta,
+            'kl_theta': kl_theta
+        }
+
+        loss += kl_alpha
+        rst_dict['kl_alpha'] = kl_alpha
+
+        return rst_dict
+
+    def init_hidden(self):
+        """Initializes the first hidden state of the RNN used as inference network for \\eta.
+        """
+        weight = next(self.parameters())
+        nlayers = self.eta_nlayers
+        nhid = self.eta_hidden_size
+        return (weight.new_zeros(nlayers, 1, nhid), weight.new_zeros(nlayers, 1, nhid))
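For DETM the topic-word matrix can be inspected without running a full training step: in eval mode get_beta draws alpha from the variational means and returns one topic-word distribution per time slice. A minimal sketch with made-up sizes (illustrative only, not part of this commit):

import torch
from backend.models.DETM import DETM

V, T, N, K = 1000, 5, 200, 20
rnn_inp = torch.rand(T, V)   # per-slice word-frequency input for the eta LSTM
model = DETM(vocab_size=V, num_times=T, train_size=N, train_time_wordfreq=rnn_inp, num_topics=K)
model.eval()
beta = model.get_beta()      # shape (T, K, V); each row is a distribution over the vocabulary
print(beta.shape, beta[0].sum(dim=-1)[:3])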
backend/models/DTM_trainer.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gensim
|
| 2 |
+
import numpy as np
|
| 3 |
+
from gensim.models import ldaseqmodel
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import datetime
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
from backend.datasets.utils import _utils
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = Logger("WARNING")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def work(arguments):
|
| 15 |
+
model, docs = arguments
|
| 16 |
+
theta_list = list()
|
| 17 |
+
for doc in tqdm(docs):
|
| 18 |
+
theta_list.append(model[doc])
|
| 19 |
+
return theta_list
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DTMTrainer:
|
| 23 |
+
def __init__(self,
|
| 24 |
+
dataset,
|
| 25 |
+
num_topics=50,
|
| 26 |
+
num_top_words=15,
|
| 27 |
+
alphas=0.01,
|
| 28 |
+
chain_variance=0.005,
|
| 29 |
+
passes=10,
|
| 30 |
+
lda_inference_max_iter=25,
|
| 31 |
+
em_min_iter=6,
|
| 32 |
+
em_max_iter=20,
|
| 33 |
+
verbose=False
|
| 34 |
+
):
|
| 35 |
+
|
| 36 |
+
self.dataset = dataset
|
| 37 |
+
self.vocab_size = dataset.vocab_size
|
| 38 |
+
self.num_topics = num_topics
|
| 39 |
+
self.num_top_words = num_top_words
|
| 40 |
+
self.alphas = alphas
|
| 41 |
+
self.chain_variance = chain_variance
|
| 42 |
+
self.passes = passes
|
| 43 |
+
self.lda_inference_max_iter = lda_inference_max_iter
|
| 44 |
+
self.em_min_iter = em_min_iter
|
| 45 |
+
self.em_max_iter = em_max_iter
|
| 46 |
+
|
| 47 |
+
self.verbose = verbose
|
| 48 |
+
if verbose:
|
| 49 |
+
logger.set_level("DEBUG")
|
| 50 |
+
else:
|
| 51 |
+
logger.set_level("WARNING")
|
| 52 |
+
|
| 53 |
+
def train(self):
|
| 54 |
+
id2word = dict(zip(range(self.vocab_size), self.dataset.vocab))
|
| 55 |
+
train_bow = self.dataset.train_bow
|
| 56 |
+
train_times = self.dataset.train_times.astype('int32')
|
| 57 |
+
|
| 58 |
+
# order documents by time slices
|
| 59 |
+
self.doc_order_idx = np.argsort(train_times)
|
| 60 |
+
train_bow = train_bow[self.doc_order_idx]
|
| 61 |
+
time_slices = np.bincount(train_times)
|
| 62 |
+
|
| 63 |
+
corpus = gensim.matutils.Dense2Corpus(train_bow, documents_columns=False)
|
| 64 |
+
|
| 65 |
+
self.model = ldaseqmodel.LdaSeqModel(
|
| 66 |
+
corpus=corpus,
|
| 67 |
+
id2word=id2word,
|
| 68 |
+
time_slice=time_slices,
|
| 69 |
+
num_topics=self.num_topics,
|
| 70 |
+
alphas=self.alphas,
|
| 71 |
+
chain_variance=self.chain_variance,
|
| 72 |
+
em_min_iter=self.em_min_iter,
|
| 73 |
+
em_max_iter=self.em_max_iter,
|
| 74 |
+
lda_inference_max_iter=self.lda_inference_max_iter,
|
| 75 |
+
passes=self.passes
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def test(self, bow):
|
| 79 |
+
# bow = dataset.bow.cpu().numpy()
|
| 80 |
+
# times = dataset.times.cpu().numpy()
|
| 81 |
+
corpus = gensim.matutils.Dense2Corpus(bow, documents_columns=False)
|
| 82 |
+
|
| 83 |
+
num_workers = 20
|
| 84 |
+
split_idx_list = np.array_split(np.arange(len(bow)), num_workers)
|
| 85 |
+
worker_size_list = [len(x) for x in split_idx_list]
|
| 86 |
+
|
| 87 |
+
worker_id = 0
|
| 88 |
+
docs_list = [list() for i in range(num_workers)]
|
| 89 |
+
for i, doc in enumerate(corpus):
|
| 90 |
+
docs_list[worker_id].append(doc)
|
| 91 |
+
if len(docs_list[worker_id]) >= worker_size_list[worker_id]:
|
| 92 |
+
worker_id += 1
|
| 93 |
+
|
| 94 |
+
args_list = list()
|
| 95 |
+
for docs in docs_list:
|
| 96 |
+
args_list.append([self.model, docs])
|
| 97 |
+
|
| 98 |
+
starttime = datetime.datetime.now()
|
| 99 |
+
|
| 100 |
+
pool = Pool(processes=num_workers)
|
| 101 |
+
results = pool.map(work, args_list)
|
| 102 |
+
|
| 103 |
+
pool.close()
|
| 104 |
+
pool.join()
|
| 105 |
+
|
| 106 |
+
theta_list = list()
|
| 107 |
+
for rst in results:
|
| 108 |
+
theta_list.extend(rst)
|
| 109 |
+
|
| 110 |
+
endtime = datetime.datetime.now()
|
| 111 |
+
|
| 112 |
+
print("DTM test time: {}s".format((endtime - starttime).seconds))
|
| 113 |
+
|
| 114 |
+
return np.asarray(theta_list)
|
| 115 |
+
|
| 116 |
+
def get_theta(self):
|
| 117 |
+
theta = self.model.gammas / self.model.gammas.sum(axis=1)[:, np.newaxis]
|
| 118 |
+
# NOTE: must map gamma back to the original document order (train() sorted documents by time slice).
|
| 119 |
+
return theta[np.argsort(self.doc_order_idx)]
|
| 120 |
+
|
| 121 |
+
def get_beta(self):
|
| 122 |
+
beta = list()
|
| 123 |
+
# K x V x T
|
| 124 |
+
for item in self.model.topic_chains:
|
| 125 |
+
# V x T
|
| 126 |
+
beta.append(item.e_log_prob)
|
| 127 |
+
|
| 128 |
+
# T x K x V
|
| 129 |
+
beta = np.transpose(np.asarray(beta), (2, 0, 1))
|
| 130 |
+
# exponentiate the log-probabilities and renormalize over the vocabulary (softmax)
|
| 131 |
+
beta = np.exp(beta)
|
| 132 |
+
beta = beta / beta.sum(-1, keepdims=True)
|
| 133 |
+
return beta
|
| 134 |
+
|
| 135 |
+
def get_top_words(self, num_top_words=None):
|
| 136 |
+
if num_top_words is None:
|
| 137 |
+
num_top_words = self.num_top_words
|
| 138 |
+
beta = self.get_beta()
|
| 139 |
+
top_words_list = list()
|
| 140 |
+
for time in range(beta.shape[0]):
|
| 141 |
+
top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
|
| 142 |
+
top_words_list.append(top_words)
|
| 143 |
+
return top_words_list
|
| 144 |
+
|
| 145 |
+
def export_theta(self):
|
| 146 |
+
train_theta = self.get_theta()
|
| 147 |
+
test_theta = self.test(self.dataset.test_bow)
|
| 148 |
+
return train_theta, test_theta
|
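A minimal usage sketch for the DTMTrainer added above (illustrative only, not part of the diff). It assumes a dataset object, built elsewhere in the project (e.g. via backend.datasets.dynamic_dataset), that exposes vocab, vocab_size, train_bow, train_times and test_bow as referenced in the class:

# Illustrative sketch; `dataset` is assumed to be constructed by the project's dataset loader.
from backend.models.DTM_trainer import DTMTrainer

trainer = DTMTrainer(dataset, num_topics=20, num_top_words=10, verbose=True)
trainer.train()                                   # fit gensim's LdaSeqModel on the time-ordered BoW matrix
beta = trainer.get_beta()                         # T x K x V topic-word probabilities per time slice
top_words = trainer.get_top_words()               # top words per topic, one list per time slice
train_theta, test_theta = trainer.export_theta()  # document-topic proportions for train and test splits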
backend/models/dynamic_trainer.py
ADDED
|
@@ -0,0 +1,177 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.optim.lr_scheduler import StepLR
|
| 7 |
+
from backend.datasets.utils import _utils
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
|
| 10 |
+
logger = Logger("WARNING")
|
| 11 |
+
|
| 12 |
+
class DynamicTrainer:
|
| 13 |
+
def __init__(self,
|
| 14 |
+
model,
|
| 15 |
+
dataset,
|
| 16 |
+
num_top_words=15,
|
| 17 |
+
epochs=200,
|
| 18 |
+
learning_rate=0.002,
|
| 19 |
+
batch_size=200,
|
| 20 |
+
lr_scheduler=None,
|
| 21 |
+
lr_step_size=125,
|
| 22 |
+
log_interval=5,
|
| 23 |
+
verbose=False
|
| 24 |
+
):
|
| 25 |
+
|
| 26 |
+
self.model = model
|
| 27 |
+
self.dataset = dataset
|
| 28 |
+
self.num_top_words = num_top_words
|
| 29 |
+
self.epochs = epochs
|
| 30 |
+
self.learning_rate = learning_rate
|
| 31 |
+
self.batch_size = batch_size
|
| 32 |
+
self.lr_scheduler = lr_scheduler
|
| 33 |
+
self.lr_step_size = lr_step_size
|
| 34 |
+
self.log_interval = log_interval
|
| 35 |
+
|
| 36 |
+
self.verbose = verbose
|
| 37 |
+
if verbose:
|
| 38 |
+
logger.set_level("DEBUG")
|
| 39 |
+
else:
|
| 40 |
+
logger.set_level("WARNING")
|
| 41 |
+
|
| 42 |
+
def make_optimizer(self):
|
| 43 |
+
args_dict = {
|
| 44 |
+
'params': self.model.parameters(),
|
| 45 |
+
'lr': self.learning_rate,
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
optimizer = torch.optim.Adam(**args_dict)
|
| 49 |
+
return optimizer
|
| 50 |
+
|
| 51 |
+
def make_lr_scheduler(self, optimizer):
|
| 52 |
+
lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5, verbose=False)
|
| 53 |
+
return lr_scheduler
|
| 54 |
+
|
| 55 |
+
def train(self):
|
| 56 |
+
optimizer = self.make_optimizer()
|
| 57 |
+
|
| 58 |
+
if self.lr_scheduler:
|
| 59 |
+
logger.info("using lr_scheduler")
|
| 60 |
+
lr_scheduler = self.make_lr_scheduler(optimizer)
|
| 61 |
+
|
| 62 |
+
data_size = len(self.dataset.train_dataloader.dataset)
|
| 63 |
+
|
| 64 |
+
for epoch in tqdm(range(1, self.epochs + 1)):
|
| 65 |
+
self.model.train()
|
| 66 |
+
loss_rst_dict = defaultdict(float)
|
| 67 |
+
|
| 68 |
+
for batch_data in self.dataset.train_dataloader:
|
| 69 |
+
|
| 70 |
+
rst_dict = self.model(batch_data['bow'], batch_data['times'])
|
| 71 |
+
batch_loss = rst_dict['loss']
|
| 72 |
+
|
| 73 |
+
optimizer.zero_grad()
|
| 74 |
+
batch_loss.backward()
|
| 75 |
+
optimizer.step()
|
| 76 |
+
|
| 77 |
+
for key in rst_dict:
|
| 78 |
+
loss_rst_dict[key] += rst_dict[key] * len(batch_data['bow'])  # weight by batch size; len(batch_data) would count dict keys
|
| 79 |
+
|
| 80 |
+
if self.lr_scheduler:
|
| 81 |
+
lr_scheduler.step()
|
| 82 |
+
|
| 83 |
+
if epoch % self.log_interval == 0:
|
| 84 |
+
output_log = f'Epoch: {epoch:03d}'
|
| 85 |
+
for key in loss_rst_dict:
|
| 86 |
+
output_log += f' {key}: {loss_rst_dict[key] / data_size :.3f}'
|
| 87 |
+
|
| 88 |
+
logger.info(output_log)
|
| 89 |
+
|
| 90 |
+
top_words = self.get_top_words()
|
| 91 |
+
train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
|
| 92 |
+
|
| 93 |
+
return top_words, train_theta
|
| 94 |
+
|
| 95 |
+
def test(self, bow, times):
|
| 96 |
+
data_size = bow.shape[0]
|
| 97 |
+
theta = list()
|
| 98 |
+
all_idx = torch.split(torch.arange(data_size), self.batch_size)
|
| 99 |
+
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
self.model.eval()
|
| 102 |
+
for idx in all_idx:
|
| 103 |
+
batch_theta = self.model.get_theta(bow[idx], times[idx])
|
| 104 |
+
theta.extend(batch_theta.cpu().tolist())
|
| 105 |
+
|
| 106 |
+
theta = np.asarray(theta)
|
| 107 |
+
return theta
|
| 108 |
+
|
| 109 |
+
def get_beta(self):
|
| 110 |
+
self.model.eval()
|
| 111 |
+
beta = self.model.get_beta().detach().cpu().numpy()
|
| 112 |
+
return beta
|
| 113 |
+
|
| 114 |
+
def get_top_words(self, num_top_words=None):
|
| 115 |
+
if num_top_words is None:
|
| 116 |
+
num_top_words = self.num_top_words
|
| 117 |
+
|
| 118 |
+
beta = self.get_beta()
|
| 119 |
+
top_words_list = list()
|
| 120 |
+
for time in range(beta.shape[0]):
|
| 121 |
+
if self.verbose:
|
| 122 |
+
print(f"======= Time: {time} =======")
|
| 123 |
+
top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
|
| 124 |
+
top_words_list.append(top_words)
|
| 125 |
+
return top_words_list
|
| 126 |
+
|
| 127 |
+
def export_theta(self):
|
| 128 |
+
train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
|
| 129 |
+
test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)
|
| 130 |
+
|
| 131 |
+
return train_theta, test_theta
|
| 132 |
+
|
| 133 |
+
def get_top_words_at_time(self, topic_id, time, top_n):
|
| 134 |
+
beta = self.get_beta() # shape: [T, K, V]
|
| 135 |
+
topic_beta = beta[time, topic_id, :]
|
| 136 |
+
top_indices = topic_beta.argsort()[-top_n:][::-1]
|
| 137 |
+
return [self.dataset.vocab[i] for i in top_indices]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_topic_words_over_time(self, topic_id, top_n):
|
| 141 |
+
"""
|
| 142 |
+
Returns top_n words for the given topic_id over all time steps.
|
| 143 |
+
Output: List[List[str]], each inner list is the top_n words at a time step.
|
| 144 |
+
"""
|
| 145 |
+
beta = self.get_beta() # shape: [T, K, V]
|
| 146 |
+
T = beta.shape[0]
|
| 147 |
+
return [
|
| 148 |
+
self.get_top_words_at_time(topic_id=topic_id, time=t, top_n=top_n)
|
| 149 |
+
for t in range(T)
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
def get_all_topics_at_time(self, time, top_n):
|
| 153 |
+
"""
|
| 154 |
+
Returns top_n words for each topic at the given time step.
|
| 155 |
+
Output: List[List[str]], each inner list is the top_n words for a topic.
|
| 156 |
+
"""
|
| 157 |
+
beta = self.get_beta() # shape: [T, K, V]
|
| 158 |
+
K = beta.shape[1]
|
| 159 |
+
return [
|
| 160 |
+
self.get_top_words_at_time(topic_id=k, time=time, top_n=top_n)
|
| 161 |
+
for k in range(K)
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
def get_all_topics_over_time(self, top_n=10):
|
| 165 |
+
"""
|
| 166 |
+
Returns the top_n words for all topics over all time steps.
|
| 167 |
+
Output shape: List[List[List[str]]] = T x K x top_n
|
| 168 |
+
"""
|
| 169 |
+
beta = self.get_beta() # shape: [T, K, V]
|
| 170 |
+
T, K, _ = beta.shape
|
| 171 |
+
return [
|
| 172 |
+
[
|
| 173 |
+
self.get_top_words_at_time(topic_id=k, time=t, top_n=top_n)
|
| 174 |
+
for k in range(K)
|
| 175 |
+
]
|
| 176 |
+
for t in range(T)
|
| 177 |
+
]
|
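A corresponding sketch for DynamicTrainer (illustrative only). It assumes `model` is one of the neural dynamic topic models in backend/models (e.g. DETM or CFDTM) whose forward pass returns a dict with a 'loss' entry, and that `dataset` provides train_dataloader, train_bow, train_times, test_bow, test_times and vocab:

# Illustrative sketch; the model and dataset objects are assumed, not defined in this diff.
from backend.models.dynamic_trainer import DynamicTrainer

trainer = DynamicTrainer(model, dataset, epochs=100, batch_size=200, lr_scheduler=True, verbose=True)
top_words, train_theta = trainer.train()               # per-slice top words and train-set theta
beta = trainer.get_beta()                              # [T, K, V] topic-word distributions
topics_t0 = trainer.get_all_topics_at_time(time=0, top_n=10)
evolution = trainer.get_topic_words_over_time(topic_id=3, top_n=10)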
data/ACL_Anthology/CFDTM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34984bfb432a10733161a9dfed834a9ef4f366a28a6cb2ecd6e8351997f1599a
|
| 3 |
+
size 16645248
|
data/ACL_Anthology/DETM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c6eefa9b6aaea4c694736d09ad9e517446f09929c01889e26633300e5eff166
|
| 3 |
+
size 41612928
|
data/ACL_Anthology/DTM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14c296a2e3fb49f9d0b66262907d64f7d181408768e43138d57c262ea6a11318
|
| 3 |
+
size 33290368
|
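The beta.npy files above are Git LFS pointers to precomputed topic-word tensors exported by the trainers. A hedged loading sketch, assuming the T x K x V layout returned by get_beta() and a vocabulary matching processed/vocab.txt:

# Illustrative only; the [T, K, V] layout and the need for a prior `git lfs pull` are assumptions.
import numpy as np

beta = np.load("data/ACL_Anthology/DTM/beta.npy")
T, K, V = beta.shape                          # time slices x topics x vocabulary size (assumed layout)
top_idx = beta[0, 0].argsort()[-10:][::-1]    # ten most probable word indices for topic 0 at the first slice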
data/ACL_Anthology/DTM/topic_label_cache.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea9f3c508ede82967cdf02050d7383d58dd9d269a7f661ae1462a95cbac3331e
|
| 3 |
+
size 2089
|
data/ACL_Anthology/docs.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a004dd095b9a4f29fdccb5144d50d3dacc7985af443a8de434005b7b8401f9b7
|
| 3 |
+
size 67395059
|
data/ACL_Anthology/inverted_index.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60e7ee888abb2fd025b11415a7ead6780d41c5f890cc25ba453615906f10b8d7
|
| 3 |
+
size 30865281
|
data/ACL_Anthology/processed/lemma_to_forms.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:00ea8855f9ced2ca3d785ce5926ced29b35e0779cd6b3166edfd5c5a5c1beccb
|
| 3 |
+
size 4370995
|
data/ACL_Anthology/processed/length_stats.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5cc985e5a1ce565ca4179d343ade1526daab463520f6317122953da83d368306
|
| 3 |
+
size 133
|
data/ACL_Anthology/processed/time2id.txt
ADDED
|
@@ -0,0 +1,18 @@
|
| 1 |
+
{
|
| 2 |
+
"2010": 0,
|
| 3 |
+
"2011": 1,
|
| 4 |
+
"2012": 2,
|
| 5 |
+
"2013": 3,
|
| 6 |
+
"2014": 4,
|
| 7 |
+
"2015": 5,
|
| 8 |
+
"2016": 6,
|
| 9 |
+
"2017": 7,
|
| 10 |
+
"2018": 8,
|
| 11 |
+
"2019": 9,
|
| 12 |
+
"2020": 10,
|
| 13 |
+
"2021": 11,
|
| 14 |
+
"2022": 12,
|
| 15 |
+
"2023": 13,
|
| 16 |
+
"2024": 14,
|
| 17 |
+
"2025": 15
|
| 18 |
+
}
|
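time2id.txt stores a JSON object mapping publication years (2010-2025) to the consecutive time-slice indices (0-15) consumed by the trainers. A small sketch of how such a mapping might be applied, using a hypothetical helper that is not part of the diff:

# Hypothetical helper: convert document years into the integer time slices used by the trainers.
import json

def years_to_slices(years, path="data/ACL_Anthology/processed/time2id.txt"):
    with open(path) as f:
        time2id = json.load(f)                 # the file is JSON despite the .txt extension
    return [time2id[str(y)] for y in years]

# years_to_slices([2010, 2019, 2025]) -> [0, 9, 15]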
data/ACL_Anthology/processed/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|