|
|
|
|
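"""SNAP: a Streamlit app to explore, filter, semantically search, cluster, and summarize textual datasets."""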
|
import streamlit as st |
|
|
|
|
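# st.set_page_config must be the first Streamlit command executed in the script.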
|
st.set_page_config(page_title="SNAP", layout="wide") |
|
|
|
|
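# Suppress noisy torch.classes warnings that some torch/Streamlit combinations emit on startup.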
|
import warnings |
|
|
|
warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*') |
|
warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*') |
|
|
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import io |
|
import time |
|
from datetime import datetime |
|
import base64 |
|
import re |
|
import pickle |
|
from typing import List, Dict, Any, Tuple |
|
import plotly.express as px |
|
import torch |
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from functools import partial |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from bertopic import BERTopic |
|
from hdbscan import HDBSCAN |
|
import nltk |
|
from nltk.corpus import stopwords |
|
from nltk.tokenize import word_tokenize |
|
|
|
|
|
from langchain.chains import LLMChain |
|
from langchain_community.chat_models import ChatOpenAI |
|
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate |
|
from openai import OpenAI |
|
from transformers import GPT2TokenizerFast |
|
|
|
|
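# OpenAI client for chat completions; picks up the API key from the OPENAI_API_KEY environment variable.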
|
client = OpenAI() |
|
|
|
|
|
|
|
|
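# Resolve the application's base directory (falls back to the current working directory
# when __file__ is unavailable, e.g. in interactive sessions).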
|
def get_base_dir(): |
|
try: |
|
base_dir = os.path.dirname(__file__) |
|
if not base_dir: |
|
return os.getcwd() |
|
return base_dir |
|
except NameError: |
|
|
|
return os.getcwd() |
|
|
|
BASE_DIR = get_base_dir() |
|
|
|
|
|
def get_model_dir(): |
|
base_dir = get_base_dir() |
|
model_dir = os.path.join(base_dir, 'models') |
|
os.makedirs(model_dir, exist_ok=True) |
|
return model_dir |
|
|
|
|
|
def load_tokenizer(): |
|
model_dir = get_model_dir() |
|
tokenizer_dir = os.path.join(model_dir, 'tokenizer') |
|
os.makedirs(tokenizer_dir, exist_ok=True) |
|
|
|
try: |
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir) |
|
|
|
except Exception as e: |
|
|
|
try: |
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") |
|
tokenizer.save_pretrained(tokenizer_dir) |
|
|
|
except Exception as download_e: |
|
|
|
raise |
|
|
|
return tokenizer |
|
|
|
|
|
try: |
|
tokenizer = load_tokenizer() |
|
except Exception as e: |
|
|
|
tokenizer = None |
|
|
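# Token budget used when trimming cluster text before summarization (gpt-4o-mini context window).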
|
MAX_CONTEXT_WINDOW = 128000 |
|
|
|
|
|
if 'chat_history' not in st.session_state: |
|
st.session_state.chat_history = [] |
|
|
|
|
|
|
|
|
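# Thin wrapper around the chat completions API; returns None and shows the error in the UI on failure.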
|
def get_chat_response(messages): |
|
try: |
|
response = client.chat.completions.create( |
|
model="gpt-4o-mini", |
|
messages=messages, |
|
temperature=0, |
|
) |
|
return response.choices[0].message.content.strip() |
|
except Exception as e: |
|
st.error(f"Error querying OpenAI: {e}") |
|
return None |
|
|
|
|
|
|
|
|
|
def generate_raw_cluster_summary( |
|
topic_val: int, |
|
cluster_df: pd.DataFrame, |
|
llm: Any, |
|
chat_prompt: Any |
|
) -> Dict[str, Any]: |
|
"""Generate a summary for a single cluster without reference enhancement, |
|
automatically trimming text if it exceeds a safe token limit.""" |
|
cluster_text = " ".join(cluster_df['text'].tolist()) |
|
if not cluster_text.strip(): |
|
return None |
|
|
|
|
|
safe_limit = int(MAX_CONTEXT_WINDOW * 0.95) |
|
|
|
|
|
    if tokenizer is not None:
        # Trim overly long cluster text so the request stays within the model's context window.
        encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
        if len(encoded_text) > safe_limit:
            encoded_text = encoded_text[:safe_limit]
            cluster_text = tokenizer.decode(encoded_text)
|
|
|
user_prompt_local = f"**Text to summarize**: {cluster_text}" |
|
try: |
|
local_chain = LLMChain(llm=llm, prompt=chat_prompt) |
|
summary_local = local_chain.run(user_prompt=user_prompt_local).strip() |
|
return {'Topic': topic_val, 'Summary': summary_local} |
|
except Exception as e: |
|
st.error(f"Error generating summary for cluster {topic_val}: {str(e)}") |
|
return None |
|
|
|
|
|
|
|
|
|
def enhance_summary_with_references( |
|
summary_dict: Dict[str, Any], |
|
df_scope: pd.DataFrame, |
|
reference_id_column: str, |
|
url_column: str = None, |
|
llm: Any = None |
|
) -> Dict[str, Any]: |
|
"""Add references to a summary.""" |
|
if not summary_dict or 'Summary' not in summary_dict: |
|
return summary_dict |
|
|
|
try: |
|
cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']] |
|
enhanced = add_references_to_summary( |
|
summary_dict['Summary'], |
|
cluster_df, |
|
reference_id_column, |
|
url_column, |
|
llm |
|
) |
|
summary_dict['Enhanced_Summary'] = enhanced |
|
return summary_dict |
|
except Exception as e: |
|
st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}") |
|
return summary_dict |
|
|
|
|
|
|
|
|
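# Orchestrates three phases behind a shared progress bar:
#   1) raw per-cluster summaries, 2) optional reference enhancement, 3) cluster naming.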
|
def process_summaries_in_parallel( |
|
df_scope: pd.DataFrame, |
|
unique_selected_topics: List[int], |
|
llm: Any, |
|
chat_prompt: Any, |
|
enable_references: bool = False, |
|
reference_id_column: str = None, |
|
url_column: str = None, |
|
max_workers: int = 16 |
|
) -> List[Dict[str, Any]]: |
|
"""Process multiple cluster summaries in parallel using ThreadPoolExecutor.""" |
|
summaries = [] |
|
total_topics = len(unique_selected_topics) |
|
|
|
|
|
progress_text = st.empty() |
|
progress_bar = st.progress(0) |
|
|
|
try: |
|
|
|
progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)") |
|
completed_summaries = 0 |
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor: |
|
|
|
future_to_topic = { |
|
executor.submit( |
|
generate_raw_cluster_summary, |
|
topic_val, |
|
df_scope[df_scope['Topic'] == topic_val], |
|
llm, |
|
chat_prompt |
|
): topic_val |
|
for topic_val in unique_selected_topics |
|
} |
|
|
|
|
|
            for future in as_completed(future_to_topic):
|
try: |
|
result = future.result() |
|
if result: |
|
summaries.append(result) |
|
completed_summaries += 1 |
|
|
|
progress = completed_summaries / total_topics |
|
progress_bar.progress(progress) |
|
progress_text.text( |
|
f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)" |
|
) |
|
except Exception as e: |
|
topic_val = future_to_topic[future] |
|
st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}") |
|
completed_summaries += 1 |
|
continue |
|
|
|
|
|
if enable_references and reference_id_column and summaries: |
|
total_to_enhance = len(summaries) |
|
completed_enhancements = 0 |
|
progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)") |
|
progress_bar.progress(0) |
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor: |
|
|
|
future_to_summary = { |
|
executor.submit( |
|
enhance_summary_with_references, |
|
summary_dict, |
|
df_scope, |
|
reference_id_column, |
|
url_column, |
|
llm |
|
): summary_dict.get('Topic') |
|
for summary_dict in summaries |
|
} |
|
|
|
|
|
enhanced_summaries = [] |
|
                for future in as_completed(future_to_summary):
|
try: |
|
result = future.result() |
|
if result: |
|
enhanced_summaries.append(result) |
|
completed_enhancements += 1 |
|
|
|
progress = completed_enhancements / total_to_enhance |
|
progress_bar.progress(progress) |
|
progress_text.text( |
|
f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)" |
|
) |
|
except Exception as e: |
|
topic_val = future_to_summary[future] |
|
st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}") |
|
completed_enhancements += 1 |
|
continue |
|
|
|
summaries = enhanced_summaries |
|
|
|
|
|
if summaries: |
|
total_to_name = len(summaries) |
|
completed_names = 0 |
|
progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)") |
|
progress_bar.progress(0) |
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor: |
|
|
|
future_to_summary = { |
|
executor.submit( |
|
generate_cluster_name, |
|
summary_dict.get('Enhanced_Summary', summary_dict['Summary']), |
|
llm |
|
): summary_dict.get('Topic') |
|
for summary_dict in summaries |
|
} |
|
|
|
|
|
named_summaries = [] |
|
                for future in as_completed(future_to_summary):
|
try: |
|
cluster_name = future.result() |
|
topic_val = future_to_summary[future] |
|
|
|
summary_dict = next(s for s in summaries if s['Topic'] == topic_val) |
|
summary_dict['Cluster_Name'] = cluster_name |
|
named_summaries.append(summary_dict) |
|
completed_names += 1 |
|
|
|
progress = completed_names / total_to_name |
|
progress_bar.progress(progress) |
|
progress_text.text( |
|
f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)" |
|
) |
|
except Exception as e: |
|
topic_val = future_to_summary[future] |
|
st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}") |
|
completed_names += 1 |
|
continue |
|
|
|
summaries = named_summaries |
|
finally: |
|
|
|
progress_text.empty() |
|
progress_bar.empty() |
|
|
|
return summaries |
|
|
|
|
|
|
|
|
|
def generate_cluster_name(summary_text: str, llm: Any) -> str: |
|
"""Generate a concise, descriptive name for a cluster based on its summary.""" |
|
system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster. |
|
|
|
Rules: |
|
1. Keep it between 3-6 words |
|
2. Be specific but concise |
|
3. Capture the main theme/focus |
|
4. Use title case
5. Do not include words like "Cluster", "Topic", or "Theme"
6. Focus on the content, not metadata
|
|
|
Example good names: |
|
- Agricultural Water Management Innovation |
|
- Gender Equality in Farming |
|
- Climate-Smart Village Implementation |
|
- Sustainable Livestock Practices""" |
|
|
|
messages = [ |
|
{"role": "system", "content": system_prompt}, |
|
{"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"} |
|
] |
|
|
|
try: |
|
        response = get_chat_response(messages)
        if not response:
            return "Unnamed Cluster"

        cluster_name = response.strip().strip('"').strip("'").strip()
|
return cluster_name |
|
except Exception as e: |
|
st.error(f"Error generating cluster name: {str(e)}") |
|
return "Unnamed Cluster" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_nltk_resources(): |
|
"""Initialize NLTK resources with better error handling and less verbose output""" |
|
nltk.data.path.append('/home/appuser/nltk_data') |
|
|
|
resources = { |
|
        'tokenizers/punkt_tab': 'punkt_tab',
|
'corpora/stopwords': 'stopwords' |
|
} |
|
|
|
for resource_path, resource_name in resources.items(): |
|
try: |
|
nltk.data.find(resource_path) |
|
except LookupError: |
|
try: |
|
nltk.download(resource_name, quiet=True) |
|
except Exception as e: |
|
st.warning(f"Error downloading NLTK resource {resource_name}: {e}") |
|
|
|
|
|
try: |
|
from nltk.tokenize import PunktSentenceTokenizer |
|
tokenizer = PunktSentenceTokenizer() |
|
tokenizer.tokenize("Test sentence.") |
|
except Exception as e: |
|
st.error(f"Error initializing NLTK tokenizer: {e}") |
|
try: |
|
nltk.download('punkt_tab', quiet=True) |
|
except Exception as e: |
|
st.error(f"Failed to download punkt_tab tokenizer: {e}") |
|
|
|
|
|
init_nltk_resources() |
|
|
|
|
|
|
|
|
|
def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None): |
|
""" |
|
Add references to a summary by identifying which parts of the summary come |
|
from which source documents. References will be appended as [ID], |
|
optionally linked if a URL column is provided. |
|
|
|
Args: |
|
summary (str): The summary text to enhance with references. |
|
source_df (DataFrame): DataFrame containing the source documents. |
|
reference_column (str): Column name to use for reference IDs. |
|
url_column (str, optional): Column name containing URLs for hyperlinks. |
|
llm (LLM, optional): Language model for source attribution. |
|
Returns: |
|
str: Enhanced summary with references as HTML if possible. |
|
""" |
|
if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns: |
|
return summary |
|
|
|
|
|
if llm is None: |
|
return summary |
|
|
|
|
|
paragraphs = summary.split('\n\n') |
|
enhanced_paragraphs = [] |
|
|
|
|
|
source_texts = [] |
|
reference_ids = [] |
|
urls = [] |
|
for _, row in source_df.iterrows(): |
|
if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]): |
|
source_texts.append(str(row['text'])) |
|
reference_ids.append(str(row[reference_column])) |
|
if url_column and url_column in row and pd.notna(row[url_column]): |
|
urls.append(str(row[url_column])) |
|
else: |
|
urls.append(None) |
|
if not source_texts: |
|
return summary |
|
|
|
|
|
url_map = {} |
|
for ref_id, u in zip(reference_ids, urls): |
|
if u: |
|
url_map[ref_id] = u |
|
|
|
|
|
system_prompt = """ |
|
You are an expert at identifying the source of information. You will be given: |
|
1. A sentence or bullet point from a summary |
|
2. A list of source texts with their IDs |
|
|
|
Your task is to identify which source text(s) the text most likely came from. |
|
Return ONLY the IDs of the source texts that contributed to the text, separated by commas. |
|
If you cannot confidently attribute the text to any source, return "unknown". |
|
""" |
|
|
|
for paragraph in paragraphs: |
|
if not paragraph.strip(): |
|
enhanced_paragraphs.append('') |
|
continue |
|
|
|
|
|
if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')): |
|
|
|
bullet_lines = paragraph.split('\n') |
|
enhanced_bullets = [] |
|
for line in bullet_lines: |
|
if not line.strip(): |
|
enhanced_bullets.append(line) |
|
continue |
|
|
|
if line.strip().startswith('- ') or line.strip().startswith('* '): |
|
|
|
source_texts_formatted = '\n'.join([f"ID: {ref_id}, Text: {text[:500]}..." for ref_id, text in zip(reference_ids, source_texts)]) |
|
user_prompt = f""" |
|
Text: {line.strip()} |
|
|
|
Source texts: |
|
{source_texts_formatted} |
|
|
|
Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown". |
|
""" |
|
|
|
try: |
|
system_message = SystemMessagePromptTemplate.from_template(system_prompt) |
|
                        human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message]) |
|
chain = LLMChain(llm=llm, prompt=chat_prompt) |
|
response = chain.run(user_prompt=user_prompt) |
|
source_ids = response.strip() |
|
|
|
if source_ids.lower() == "unknown": |
|
enhanced_bullets.append(line) |
|
else: |
|
|
|
source_ids = re.sub(r'[^0-9,\s]', '', source_ids) |
|
source_ids = re.sub(r'\s+', '', source_ids) |
|
ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()] |
|
|
|
if ids: |
|
ref_parts = [] |
|
for id_ in ids: |
|
if id_ in url_map: |
|
ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>') |
|
else: |
|
ref_parts.append(id_) |
|
ref_string = ", ".join(ref_parts) |
|
enhanced_bullets.append(f"{line} [{ref_string}]") |
|
else: |
|
enhanced_bullets.append(line) |
|
except Exception: |
|
enhanced_bullets.append(line) |
|
else: |
|
enhanced_bullets.append(line) |
|
|
|
enhanced_paragraphs.append('\n'.join(enhanced_bullets)) |
|
else: |
|
|
|
sentences = re.split(r'(?<=[.!?])\s+', paragraph) |
|
enhanced_sentences = [] |
|
|
|
for sentence in sentences: |
|
if not sentence.strip(): |
|
continue |
|
|
|
source_texts_formatted = '\n'.join([f"ID: {ref_id}, Text: {text[:500]}..." for ref_id, text in zip(reference_ids, source_texts)]) |
|
user_prompt = f""" |
|
Sentence: {sentence.strip()} |
|
|
|
Source texts: |
|
{source_texts_formatted} |
|
|
|
Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown". |
|
""" |
|
|
|
try: |
|
system_message = SystemMessagePromptTemplate.from_template(system_prompt) |
|
                    human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message]) |
|
chain = LLMChain(llm=llm, prompt=chat_prompt) |
|
response = chain.run(user_prompt=user_prompt) |
|
source_ids = response.strip() |
|
|
|
if source_ids.lower() == "unknown": |
|
enhanced_sentences.append(sentence) |
|
else: |
|
|
|
source_ids = re.sub(r'[^0-9,\s]', '', source_ids) |
|
source_ids = re.sub(r'\s+', '', source_ids) |
|
ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()] |
|
|
|
if ids: |
|
ref_parts = [] |
|
for id_ in ids: |
|
if id_ in url_map: |
|
ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>') |
|
else: |
|
ref_parts.append(id_) |
|
ref_string = ", ".join(ref_parts) |
|
enhanced_sentences.append(f"{sentence} [{ref_string}]") |
|
else: |
|
enhanced_sentences.append(sentence) |
|
except Exception: |
|
enhanced_sentences.append(sentence) |
|
|
|
enhanced_paragraphs.append(' '.join(enhanced_sentences)) |
|
|
|
|
|
return '\n\n'.join(enhanced_paragraphs) |
|
|
|
|
|
st.sidebar.image("static/SNAP_logo.png", width=350) |
|
|
|
|
|
|
|
|
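# Use the GPU for embedding computation when one is available.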
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
if device == 'cuda': |
|
st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}") |
|
else: |
|
st.sidebar.info("Using CPU") |
|
|
|
|
|
|
|
|
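# Cached resource: loads all-MiniLM-L6-v2 from models/sentence_transformer, downloading it on first use.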
|
@st.cache_resource |
|
def get_embedding_model(): |
|
model_dir = get_model_dir() |
|
st_model_dir = os.path.join(model_dir, 'sentence_transformer') |
|
os.makedirs(st_model_dir, exist_ok=True) |
|
|
|
model_name = 'all-MiniLM-L6-v2' |
|
try: |
|
|
|
model = SentenceTransformer(st_model_dir) |
|
|
|
except Exception as e: |
|
|
|
try: |
|
|
|
model = SentenceTransformer(model_name) |
|
model.save(st_model_dir) |
|
|
|
except Exception as download_e: |
|
st.error(f"Error downloading sentence transformer model: {str(download_e)}") |
|
raise |
|
|
|
return model.to(device) |
|
|
|
def generate_embeddings(texts, model): |
|
with st.spinner('Calculating embeddings...'): |
|
embeddings = model.encode(texts, show_progress_bar=True, device=device) |
|
return embeddings |
|
|
|
@st.cache_data |
|
def load_default_dataset(default_dataset_path): |
|
if os.path.exists(default_dataset_path): |
|
df_ = pd.read_excel(default_dataset_path) |
|
return df_ |
|
else: |
|
st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.") |
|
return None |
|
|
|
@st.cache_data |
|
def load_uploaded_dataset(uploaded_file): |
|
df_ = pd.read_excel(uploaded_file) |
|
return df_ |
|
|
|
def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None): |
|
""" |
|
Loads pre-computed embeddings from a pickle file if they match current data, |
|
otherwise computes and caches them. |
|
""" |
|
if not text_columns: |
|
return None, None |
|
|
|
base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset" |
|
if uploaded_file_name: |
|
base_name = os.path.splitext(uploaded_file_name)[0] |
|
|
|
cols_key = "_".join(sorted(text_columns)) |
|
    embeddings_dir = BASE_DIR
    # Default and uploaded datasets share the same cache file naming scheme: <dataset>_<columns>.pkl
    embeddings_file = os.path.join(embeddings_dir, f"{base_name}_{cols_key}.pkl")
|
|
|
df_fill = df.fillna("") |
|
texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist() |
|
|
|
|
|
if ('embeddings' in st.session_state |
|
and 'last_text_columns' in st.session_state |
|
and st.session_state['last_text_columns'] == text_columns |
|
and len(st.session_state['embeddings']) == len(texts)): |
|
return st.session_state['embeddings'], st.session_state.get('embeddings_file', None) |
|
|
|
|
|
if os.path.exists(embeddings_file): |
|
with open(embeddings_file, 'rb') as f: |
|
embeddings = pickle.load(f) |
|
if len(embeddings) == len(texts): |
|
st.write("Loaded pre-calculated embeddings.") |
|
st.session_state['embeddings'] = embeddings |
|
st.session_state['embeddings_file'] = embeddings_file |
|
st.session_state['last_text_columns'] = text_columns |
|
return embeddings, embeddings_file |
|
|
|
|
|
st.write("Generating embeddings...") |
|
model = get_embedding_model() |
|
embeddings = generate_embeddings(texts, model) |
|
with open(embeddings_file, 'wb') as f: |
|
pickle.dump(embeddings, f) |
|
|
|
st.session_state['embeddings'] = embeddings |
|
st.session_state['embeddings_file'] = embeddings_file |
|
st.session_state['last_text_columns'] = text_columns |
|
return embeddings, embeddings_file |
|
|
|
|
|
|
|
|
|
|
|
def reset_filters(): |
|
st.session_state['selected_additional_filters'] = {} |
|
|
|
|
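# View switch between Automatic Mode and the full Power User workflow.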
|
st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view") |
|
|
|
if st.session_state.view == "Power User Mode": |
|
st.header("Power User Mode") |
|
|
|
|
|
|
|
st.sidebar.title("Data Selection") |
|
dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset')) |
|
|
|
if 'df' not in st.session_state: |
|
st.session_state['df'] = pd.DataFrame() |
|
if 'filtered_df' not in st.session_state: |
|
st.session_state['filtered_df'] = pd.DataFrame() |
|
|
|
if dataset_option == 'PRMS 2022+2023+2024 QAed': |
|
default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx') |
|
df = load_default_dataset(default_dataset_path) |
|
if df is not None: |
|
st.session_state['df'] = df.copy() |
|
st.session_state['using_default_dataset'] = True |
|
|
|
|
|
if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty: |
|
st.session_state['filtered_df'] = df.copy() |
|
|
|
|
|
if 'filter_state' not in st.session_state: |
|
st.session_state['filter_state'] = { |
|
'applied': False, |
|
'filters': {} |
|
} |
|
|
|
|
|
if 'text_columns' not in st.session_state or not st.session_state['text_columns']: |
|
default_text_cols = [] |
|
if 'Title' in df.columns and 'Description' in df.columns: |
|
default_text_cols = ['Title', 'Description'] |
|
st.session_state['text_columns'] = default_text_cols |
|
|
|
|
|
|
|
|
|
|
|
|
|
df_cols = df.columns.tolist() |
|
|
|
|
|
st.subheader("Select Filters") |
|
if 'additional_filters_selected' not in st.session_state: |
|
st.session_state['additional_filters_selected'] = [] |
|
if 'filter_values' not in st.session_state: |
|
st.session_state['filter_values'] = {} |
|
|
|
with st.form("filter_selection_form"): |
|
all_columns = df.columns.tolist() |
|
selected_additional_cols = st.multiselect( |
|
"Select columns from your dataset to use as filters:", |
|
all_columns, |
|
default=st.session_state['additional_filters_selected'] |
|
) |
|
add_filters_submitted = st.form_submit_button("Add Additional Filters") |
|
|
|
if add_filters_submitted: |
|
if selected_additional_cols != st.session_state['additional_filters_selected']: |
|
st.session_state['additional_filters_selected'] = selected_additional_cols |
|
|
|
st.session_state['filter_values'] = { |
|
k: v for k, v in st.session_state['filter_values'].items() |
|
if k in selected_additional_cols |
|
} |
|
|
|
|
|
if st.session_state['additional_filters_selected']: |
|
st.subheader("Apply Filters") |
|
|
|
|
|
for col_name in st.session_state['additional_filters_selected']: |
|
unique_vals = sorted(df[col_name].dropna().unique().tolist()) |
|
|
|
|
|
search_key = f"search_{col_name}" |
|
if search_key not in st.session_state: |
|
st.session_state[search_key] = "" |
|
|
|
col1, col2 = st.columns([3, 1]) |
|
with col1: |
|
search_term = st.text_input( |
|
f"Search in {col_name}", |
|
key=search_key, |
|
help="Enter text to find and select all matching values" |
|
) |
|
with col2: |
|
if st.button(f"Select Matching", key=f"select_{col_name}"): |
|
|
|
if search_term: |
|
matching_vals = [ |
|
val for val in unique_vals |
|
if any(search_term.lower() in str(part).lower() |
|
for part in (val.split(',') if isinstance(val, str) else [val])) |
|
] |
|
|
|
current_selected = st.session_state['filter_values'].get(col_name, []) |
|
st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals)) |
|
|
|
|
|
if matching_vals: |
|
st.success(f"Found and selected {len(matching_vals)} matching values") |
|
else: |
|
st.warning("No matching values found") |
|
|
|
|
|
with st.form("apply_filters_form"): |
|
for col_name in st.session_state['additional_filters_selected']: |
|
unique_vals = sorted(df[col_name].dropna().unique().tolist()) |
|
selected_vals = st.multiselect( |
|
f"Filter by {col_name}", |
|
options=unique_vals, |
|
default=st.session_state['filter_values'].get(col_name, []) |
|
) |
|
st.session_state['filter_values'][col_name] = selected_vals |
|
|
|
|
|
col1, col2 = st.columns([1, 4]) |
|
with col1: |
|
clear_filters = st.form_submit_button("Clear All") |
|
with col2: |
|
apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset") |
|
|
|
if clear_filters: |
|
st.session_state['filter_values'] = {} |
|
|
|
if 'summary_df' in st.session_state: |
|
del st.session_state['summary_df'] |
|
if 'high_level_summary' in st.session_state: |
|
del st.session_state['high_level_summary'] |
|
if 'enhanced_summary' in st.session_state: |
|
del st.session_state['enhanced_summary'] |
|
st.rerun() |
|
|
|
|
|
            with st.expander("⚙️ Advanced Settings", expanded=False):
|
st.subheader("**Select Text Columns for Embedding**") |
|
text_columns_selected = st.multiselect( |
|
"Text Columns:", |
|
df_cols, |
|
default=st.session_state['text_columns'], |
|
help="Choose columns containing text for semantic search and clustering. " |
|
"If multiple are selected, their text will be concatenated." |
|
) |
|
st.session_state['text_columns'] = text_columns_selected |
|
|
|
|
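            # Apply the currently selected filter values to the working dataframe and cache the result in session state.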
|
filtered_df = df.copy() |
|
if 'apply_filters_submitted' in locals() and apply_filters_submitted: |
|
|
|
if 'summary_df' in st.session_state: |
|
del st.session_state['summary_df'] |
|
if 'high_level_summary' in st.session_state: |
|
del st.session_state['high_level_summary'] |
|
if 'enhanced_summary' in st.session_state: |
|
del st.session_state['enhanced_summary'] |
|
|
|
for col_name in st.session_state['additional_filters_selected']: |
|
selected_vals = st.session_state['filter_values'].get(col_name, []) |
|
if selected_vals: |
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)] |
|
st.success("Filters applied successfully!") |
|
st.session_state['filtered_df'] = filtered_df.copy() |
|
st.session_state['filter_state'] = { |
|
'applied': True, |
|
'filters': st.session_state['filter_values'].copy() |
|
} |
|
|
|
for k in ['clustered_data', 'topic_model', 'current_clustering_data', |
|
'current_clustering_option', 'hierarchy']: |
|
if k in st.session_state: |
|
del st.session_state[k] |
|
|
|
elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']: |
|
|
|
for col_name, selected_vals in st.session_state['filter_state']['filters'].items(): |
|
if selected_vals: |
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)] |
|
st.session_state['filtered_df'] = filtered_df.copy() |
|
|
|
|
|
if st.session_state['filtered_df'] is not None: |
|
if st.session_state['filter_state']['applied']: |
|
st.write("Filtered Data Preview:") |
|
else: |
|
st.write("Current Data Preview:") |
|
st.dataframe(st.session_state['filtered_df'].head(), hide_index=True) |
|
st.write(f"Total number of results: {len(st.session_state['filtered_df'])}") |
|
|
|
output = io.BytesIO() |
|
writer = pd.ExcelWriter(output, engine='openpyxl') |
|
st.session_state['filtered_df'].to_excel(writer, index=False) |
|
writer.close() |
|
processed_data = output.getvalue() |
|
|
|
st.download_button( |
|
label="Download Current Data", |
|
data=processed_data, |
|
file_name='data.xlsx', |
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' |
|
) |
|
else: |
|
st.warning("Please ensure the default dataset exists in the 'input' directory.") |
|
|
|
else: |
|
|
|
uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"]) |
|
if uploaded_file is not None: |
|
df = load_uploaded_dataset(uploaded_file) |
|
if df is not None: |
|
st.session_state['df'] = df.copy() |
|
st.session_state['using_default_dataset'] = False |
|
st.session_state['uploaded_file_name'] = uploaded_file.name |
|
st.write("Data preview:") |
|
st.write(df.head()) |
|
df_cols = df.columns.tolist() |
|
|
|
st.subheader("**Select Text Columns for Embedding**") |
|
text_columns_selected = st.multiselect( |
|
"Text Columns:", |
|
df_cols, |
|
default=df_cols[:1] if df_cols else [] |
|
) |
|
st.session_state['text_columns'] = text_columns_selected |
|
|
|
st.write("**Additional Filters**") |
|
selected_additional_cols = st.multiselect( |
|
"Select additional columns from your dataset to use as filters:", |
|
df_cols, |
|
default=[] |
|
) |
|
st.session_state['additional_filters_selected'] = selected_additional_cols |
|
|
|
filtered_df = df.copy() |
|
for col_name in selected_additional_cols: |
|
if f'selected_filter_{col_name}' not in st.session_state: |
|
st.session_state[f'selected_filter_{col_name}'] = [] |
|
unique_vals = sorted(df[col_name].dropna().unique().tolist()) |
|
selected_vals = st.multiselect( |
|
f"Filter by {col_name}", |
|
options=unique_vals, |
|
default=st.session_state[f'selected_filter_{col_name}'] |
|
) |
|
st.session_state[f'selected_filter_{col_name}'] = selected_vals |
|
if selected_vals: |
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)] |
|
|
|
st.session_state['filtered_df'] = filtered_df |
|
st.write("Filtered Data Preview:") |
|
st.dataframe(filtered_df.head(), hide_index=True) |
|
st.write(f"Total number of results: {len(filtered_df)}") |
|
|
|
output = io.BytesIO() |
|
writer = pd.ExcelWriter(output, engine='openpyxl') |
|
filtered_df.to_excel(writer, index=False) |
|
writer.close() |
|
processed_data = output.getvalue() |
|
|
|
st.download_button( |
|
label="Download Filtered Data", |
|
data=processed_data, |
|
file_name='filtered_data.xlsx', |
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' |
|
) |
|
else: |
|
st.warning("Failed to load the uploaded dataset.") |
|
else: |
|
st.warning("Please upload an Excel file to proceed.") |
|
|
|
if 'filtered_df' in st.session_state: |
|
st.write(f"Total number of results: {len(st.session_state['filtered_df'])}") |
|
|
|
|
|
|
|
|
|
|
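# Main navigation tabs; active_tab_index is stored in session state so later actions can point back to a specific tab.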
|
if 'active_tab_index' not in st.session_state: |
|
st.session_state.active_tab_index = 0 |
|
|
|
tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"] |
|
tabs = st.tabs(tabs_titles) |
|
|
|
tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs |
|
|
|
|
|
|
|
|
|
with tab_help: |
|
st.header("Help") |
|
st.markdown(""" |
|
### About SNAP |
|
|
|
SNAP allows you to explore, filter, search, cluster, and summarize textual datasets. |
|
|
|
**Workflow**: |
|
1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own. |
|
2. **Filtering**: Set additional filters for your dataset. |
|
3. **Select Text Columns**: Which columns to embed. |
|
4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents. |
|
5. **Clustering** (Tab): Group documents into topics. |
|
6. **Summarization** (Tab): Summarize the clustered documents (with optional references). |
|
|
|
### Troubleshooting |
|
- If you see no results, try lowering the similarity threshold or removing negative/required keywords. |
|
- Ensure you have at least one text column selected for embeddings. |
|
""") |
|
|
|
|
|
|
|
|
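# Semantic Search tab: embed the query with the same SentenceTransformer and rank documents by cosine similarity.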
|
with tab_semantic: |
|
st.header("Semantic Search") |
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
|
text_columns = st.session_state.get('text_columns', []) |
|
if not text_columns: |
|
st.warning("No text columns selected. Please select at least one column for text embedding.") |
|
else: |
|
df_full = st.session_state['df'] |
|
|
|
embeddings, _ = load_or_compute_embeddings( |
|
df_full, |
|
st.session_state.get('using_default_dataset', False), |
|
st.session_state.get('uploaded_file_name'), |
|
text_columns |
|
) |
|
|
|
if embeddings is not None: |
|
                    with st.expander("ℹ️ How Semantic Search Works", expanded=False):
|
st.markdown(""" |
|
### Understanding Semantic Search |
|
|
|
Unlike traditional keyword search that looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works: |
|
|
|
1. **Query Processing**: |
|
- Your search query is converted into a numerical representation (embedding) that captures its meaning |
|
- Example: Searching for "Climate Smart Villages" will understand the concept, not just the words |
|
- Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words |
|
|
|
2. **Similarity Matching**: |
|
- Documents are ranked by how closely their meaning matches your query |
|
- The similarity threshold controls how strict this matching is |
|
- Higher threshold (e.g., 0.8) = more precise but fewer results |
|
- Lower threshold (e.g., 0.3) = more results but might be less relevant |
|
|
|
3. **Advanced Features**: |
|
- **Negative Keywords**: Use to explicitly exclude documents containing certain terms |
|
- **Required Keywords**: Ensure specific terms appear in the results |
|
- These work as traditional keyword filters after the semantic search |
|
|
|
### Search Tips |
|
|
|
- **Phrase Queries**: Enter complete phrases for better context |
|
- "Climate Smart Villages" (as one concept) |
|
- Better than separate terms: "climate", "smart", "villages" |
|
|
|
- **Descriptive Queries**: Add context for better results |
|
- Instead of: "water" |
|
- Better: "water management in agriculture" |
|
|
|
- **Conceptual Queries**: Focus on concepts rather than specific terms |
|
- Instead of: "increased yield" |
|
- Better: "agricultural productivity improvements" |
|
|
|
### Example Searches |
|
|
|
1. **Query**: "Climate Smart Villages" |
|
- Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development |
|
- Even if they don't use these exact words |
|
|
|
2. **Query**: "Gender equality in agriculture" |
|
- Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development |
|
- Related concepts are captured semantically |
|
|
|
3. **Query**: "Sustainable water management" |
|
+ Required keyword: "irrigation" |
|
- Combines semantic understanding of water sustainability with specific irrigation focus |
|
""") |
|
|
|
with st.form("search_parameters"): |
|
query = st.text_input("Enter your search query:") |
|
include_keywords = st.text_input("Include only documents containing these words (comma-separated):") |
|
similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35) |
|
submitted = st.form_submit_button("Search") |
|
|
|
if submitted: |
|
if query.strip(): |
|
with st.spinner("Performing Semantic Search..."): |
|
|
|
if 'summary_df' in st.session_state: |
|
del st.session_state['summary_df'] |
|
if 'high_level_summary' in st.session_state: |
|
del st.session_state['high_level_summary'] |
|
if 'enhanced_summary' in st.session_state: |
|
del st.session_state['enhanced_summary'] |
|
|
|
model = get_embedding_model() |
|
df_filtered = st.session_state['filtered_df'].fillna("") |
|
search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist() |
|
|
|
|
|
subset_indices = df_filtered.index |
|
subset_embeddings = embeddings[subset_indices] |
|
|
|
query_embedding = model.encode([query], device=device) |
|
similarities = cosine_similarity(query_embedding, subset_embeddings)[0] |
|
|
|
|
|
fig = px.histogram( |
|
x=similarities, |
|
nbins=30, |
|
labels={'x': 'Similarity Score', 'y': 'Number of Documents'}, |
|
title='Distribution of Similarity Scores' |
|
) |
|
fig.add_vline( |
|
x=similarity_threshold, |
|
line_dash="dash", |
|
line_color="red", |
|
annotation_text=f"Threshold: {similarity_threshold:.2f}", |
|
annotation_position="top" |
|
) |
|
st.write("### Similarity Score Distribution") |
|
st.plotly_chart(fig) |
|
|
|
above_threshold_indices = np.where(similarities > similarity_threshold)[0] |
|
if len(above_threshold_indices) == 0: |
|
st.warning("No results found above the similarity threshold.") |
|
if 'search_results' in st.session_state: |
|
del st.session_state['search_results'] |
|
else: |
|
selected_indices = subset_indices[above_threshold_indices] |
|
results = df_filtered.loc[selected_indices].copy() |
|
results['similarity_score'] = similarities[above_threshold_indices] |
|
results.sort_values(by='similarity_score', ascending=False, inplace=True) |
|
|
|
|
|
if include_keywords.strip(): |
|
inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()] |
|
if inc_words: |
|
results = results[ |
|
results.apply( |
|
lambda row: all( |
|
w in (' '.join(row.astype(str)).lower()) for w in inc_words |
|
), |
|
axis=1 |
|
) |
|
] |
|
|
|
if results.empty: |
|
st.warning("No results found after applying keyword filters.") |
|
if 'search_results' in st.session_state: |
|
del st.session_state['search_results'] |
|
else: |
|
st.session_state['search_results'] = results.copy() |
|
output = io.BytesIO() |
|
writer = pd.ExcelWriter(output, engine='openpyxl') |
|
results.to_excel(writer, index=False) |
|
writer.close() |
|
processed_data = output.getvalue() |
|
st.session_state['search_results_processed_data'] = processed_data |
|
else: |
|
st.warning("Please enter a query to search.") |
|
|
|
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
|
st.write("## Search Results") |
|
results = st.session_state['search_results'] |
|
cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score'] |
|
st.dataframe(results[cols_to_display], hide_index=True) |
|
st.write(f"Total number of results: {len(results)}") |
|
|
|
if 'search_results_processed_data' in st.session_state: |
|
st.download_button( |
|
label="Download Full Results", |
|
data=st.session_state['search_results_processed_data'], |
|
file_name='search_results.xlsx', |
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', |
|
key='download_search_results' |
|
) |
|
else: |
|
st.info("No search results to display. Enter a query and click 'Search'.") |
|
else: |
|
st.warning("No embeddings available because no text columns were chosen.") |
|
else: |
|
st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.") |
|
|
|
|
|
|
|
|
|
|
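# Clustering tab: BERTopic over the precomputed embeddings, with HDBSCAN as the clustering backend.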
|
with tab_clustering: |
|
st.header("Clustering") |
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
|
|
|
            with st.expander("ℹ️ How Clustering Works", expanded=False):
|
st.markdown(""" |
|
### Understanding Document Clustering |
|
|
|
Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works: |
|
|
|
1. **Cluster Formation**: |
|
- Documents are grouped based on their semantic similarity |
|
- Each cluster represents a distinct theme or topic |
|
- Documents that are too different from others may remain unclustered (labeled as -1) |
|
- The "Min Cluster Size" parameter controls how clusters are formed |
|
|
|
2. **Interpreting Results**: |
|
- Each cluster is assigned a number (e.g., 0, 1, 2...) |
|
- Cluster -1 contains "outlier" documents that didn't fit well in other clusters |
|
- The size of each cluster indicates how common that theme is |
|
- Keywords for each cluster show the main topics/concepts |
|
|
|
3. **Visualizations**: |
|
- **Intertopic Distance Map**: Shows how clusters relate to each other |
|
- Closer clusters are more semantically similar |
|
- Size of circles indicates number of documents |
|
- Hover to see top terms for each cluster |
|
|
|
- **Topic Document Visualization**: Shows individual documents |
|
- Each point is a document |
|
- Colors indicate cluster membership |
|
- Distance between points shows similarity |
|
|
|
- **Topic Hierarchy**: Shows how topics are related |
|
- Tree structure shows topic relationships |
|
- Parent topics contain broader themes |
|
- Child topics show more specific sub-themes |
|
|
|
### How to Use Clusters |
|
|
|
1. **Exploration**: |
|
- Use clusters to discover main themes in your data |
|
- Look for unexpected groupings that might reveal insights |
|
- Identify outliers that might need special attention |
|
|
|
2. **Analysis**: |
|
- Compare cluster sizes to understand theme distribution |
|
- Examine keywords to understand what defines each cluster |
|
- Use hierarchy to see how themes are nested |
|
|
|
3. **Practical Applications**: |
|
- Generate summaries for specific clusters |
|
- Focus detailed analysis on clusters of interest |
|
- Use clusters to organize and categorize documents |
|
- Identify gaps or overlaps in your dataset |
|
|
|
### Tips for Better Results |
|
|
|
- **Adjust Min Cluster Size**: |
|
- Larger values (15-20): Fewer, broader clusters |
|
- Smaller values (2-5): More specific, smaller clusters |
|
- Balance between too many small clusters and too few large ones |
|
|
|
- **Choose Data Wisely**: |
|
- Cluster full dataset for overall themes |
|
- Cluster search results for focused analysis |
|
- More documents generally give better clusters |
|
|
|
- **Interpret with Context**: |
|
- Consider your domain knowledge |
|
- Look for patterns across multiple visualizations |
|
- Use cluster insights to guide further analysis |
|
""") |
|
|
|
df_to_cluster = None |
|
|
|
|
|
with st.form("clustering_form"): |
|
st.subheader("Clustering Settings") |
|
|
|
|
|
clustering_option = st.radio( |
|
"Select data for clustering:", |
|
('Full Dataset', 'Filtered Dataset', 'Semantic Search Results') |
|
) |
|
|
|
|
|
min_cluster_size_val = st.slider( |
|
"Min Cluster Size", |
|
min_value=2, |
|
max_value=50, |
|
value=st.session_state.get('min_cluster_size', 5), |
|
help="Minimum size of each cluster in HDBSCAN; In other words, it's the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)" |
|
) |
|
|
|
run_clustering = st.form_submit_button("Run Clustering") |
|
|
|
if run_clustering: |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.session_state['min_cluster_size'] = min_cluster_size_val |
|
|
|
|
|
if clustering_option == 'Semantic Search Results': |
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
|
df_to_cluster = st.session_state['search_results'].copy() |
|
else: |
|
st.warning("No semantic search results found. Please run a search first.") |
|
elif clustering_option == 'Filtered Dataset': |
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
|
df_to_cluster = st.session_state['filtered_df'].copy() |
|
else: |
|
st.warning("Filtered dataset is empty. Please check your filters.") |
|
else: |
|
if 'df' in st.session_state and not st.session_state['df'].empty: |
|
df_to_cluster = st.session_state['df'].copy() |
|
|
|
text_columns = st.session_state.get('text_columns', []) |
|
if not text_columns: |
|
st.warning("No text columns selected. Please select text columns to embed before clustering.") |
|
else: |
|
|
|
df_full = st.session_state['df'] |
|
embeddings, _ = load_or_compute_embeddings( |
|
df_full, |
|
st.session_state.get('using_default_dataset', False), |
|
st.session_state.get('uploaded_file_name'), |
|
text_columns |
|
) |
|
|
|
if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering: |
|
with st.spinner("Performing clustering..."): |
|
|
|
if 'summary_df' in st.session_state: |
|
del st.session_state['summary_df'] |
|
if 'high_level_summary' in st.session_state: |
|
del st.session_state['high_level_summary'] |
|
if 'enhanced_summary' in st.session_state: |
|
del st.session_state['enhanced_summary'] |
|
|
|
dfc = df_to_cluster.copy().fillna("") |
|
dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1) |
|
|
|
|
|
selected_indices = dfc.index |
|
embeddings_clustering = embeddings[selected_indices] |
|
|
|
|
|
stop_words = set(stopwords.words('english')) |
|
texts_cleaned = [] |
|
for text in dfc['text'].tolist(): |
|
try: |
|
|
|
try: |
|
word_tokens = word_tokenize(text) |
|
except LookupError: |
|
|
|
nltk.download('punkt_tab', quiet=False) |
|
word_tokens = word_tokenize(text) |
|
except Exception as e: |
|
|
|
st.warning(f"Using fallback tokenization due to error: {e}") |
|
word_tokens = text.split() |
|
|
|
filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words]) |
|
texts_cleaned.append(filtered_text) |
|
except Exception as e: |
|
st.error(f"Error processing text: {e}") |
|
|
|
texts_cleaned.append(text) |
|
|
|
try: |
|
|
|
if len(texts_cleaned) < min_cluster_size_val: |
|
st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.") |
|
st.session_state['clustering_error'] = "Insufficient documents for clustering" |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.stop() |
|
|
|
|
|
if torch.is_tensor(embeddings_clustering): |
|
embeddings_for_clustering = embeddings_clustering.cpu().numpy() |
|
else: |
|
embeddings_for_clustering = embeddings_clustering |
|
|
|
|
|
if embeddings_for_clustering.shape[0] != len(texts_cleaned): |
|
st.error("Mismatch between number of embeddings and texts.") |
|
st.session_state['clustering_error'] = "Embedding and text count mismatch" |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.stop() |
|
|
|
|
|
try: |
|
hdbscan_model = HDBSCAN( |
|
min_cluster_size=min_cluster_size_val, |
|
metric='euclidean', |
|
cluster_selection_method='eom' |
|
) |
|
|
|
|
|
topic_model = BERTopic( |
|
embedding_model=get_embedding_model(), |
|
hdbscan_model=hdbscan_model |
|
) |
|
|
|
|
|
topics, probs = topic_model.fit_transform( |
|
texts_cleaned, |
|
embeddings=embeddings_for_clustering |
|
) |
|
|
|
|
|
unique_topics = set(topics) |
|
if len(unique_topics) < 2: |
|
st.warning("Clustering resulted in too few clusters. Retry or try reducing the minimum cluster size.") |
|
if -1 in unique_topics: |
|
non_noise_docs = sum(1 for t in topics if t != -1) |
|
st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).") |
|
if non_noise_docs < min_cluster_size_val: |
|
st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.") |
|
st.session_state['clustering_error'] = "Insufficient clustered documents" |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.stop() |
|
|
|
|
|
dfc['Topic'] = topics |
|
st.session_state['topic_model'] = topic_model |
|
st.session_state['clustered_data'] = dfc.copy() |
|
st.session_state['clustering_texts_cleaned'] = texts_cleaned |
|
st.session_state['clustering_embeddings'] = embeddings_for_clustering |
|
st.session_state['clustering_completed'] = True |
|
|
|
|
|
try: |
|
st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics() |
|
except Exception as viz_error: |
|
st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.") |
|
st.session_state['intertopic_distance_fig'] = None |
|
|
|
try: |
|
st.session_state['topic_document_fig'] = topic_model.visualize_documents( |
|
texts_cleaned, |
|
embeddings=embeddings_for_clustering |
|
) |
|
except Exception as viz_error: |
|
st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.") |
|
st.session_state['topic_document_fig'] = None |
|
|
|
try: |
|
hierarchy = topic_model.hierarchical_topics(texts_cleaned) |
|
st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame() |
|
st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy() |
|
except Exception as viz_error: |
|
st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.") |
|
st.session_state['hierarchy'] = pd.DataFrame() |
|
st.session_state['hierarchy_fig'] = None |
|
|
|
except ValueError as ve: |
|
if "zero-size array to reduction operation maximum which has no identity" in str(ve): |
|
st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.") |
|
elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve): |
|
st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.") |
|
else: |
|
st.error(f"Clustering error: {str(ve)}") |
|
st.session_state['clustering_error'] = str(ve) |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.stop() |
|
|
|
except Exception as e: |
|
st.error(f"An error occurred during clustering: {str(e)}") |
|
st.session_state['clustering_error'] = str(e) |
|
st.session_state['clustering_completed'] = False |
|
st.session_state.active_tab_index = tabs_titles.index("Clustering") |
|
st.stop() |
|
|
|
|
|
if st.session_state.get('clustering_completed', False): |
|
st.subheader("Topic Overview") |
|
dfc = st.session_state['clustered_data'] |
|
topic_model = st.session_state['topic_model'] |
|
topics = dfc['Topic'].tolist() |
|
|
|
unique_topics = sorted(list(set(topics))) |
|
cluster_info = [] |
|
for t in unique_topics: |
|
cluster_docs = dfc[dfc['Topic'] == t] |
|
count = len(cluster_docs) |
|
top_words = topic_model.get_topic(t) |
|
if top_words: |
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) |
|
else: |
|
top_keywords = "N/A" |
|
cluster_info.append((t, count, top_keywords)) |
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"]) |
|
|
|
|
st.dataframe( |
|
cluster_df, |
|
column_config={ |
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
|
"Top Keywords": st.column_config.TextColumn( |
|
"Top Keywords", |
|
help="Top 5 keywords that characterize this topic" |
|
) |
|
}, |
|
hide_index=True |
|
) |
|
|
|
st.subheader("Clustering Results") |
|
columns_to_display = [c for c in dfc.columns if c != 'text'] |
|
st.dataframe(dfc[columns_to_display], hide_index=True) |
|
|
|
|
|
st.write("### Intertopic Distance Map") |
|
if st.session_state.get('intertopic_distance_fig') is not None: |
|
try: |
|
st.plotly_chart(st.session_state['intertopic_distance_fig']) |
|
except Exception: |
|
st.info("Topic visualization is not available for the current clustering results.") |
|
|
|
st.write("### Topic Document Visualization") |
|
if st.session_state.get('topic_document_fig') is not None: |
|
try: |
|
st.plotly_chart(st.session_state['topic_document_fig']) |
|
except Exception: |
|
st.info("Document visualization is not available for the current clustering results.") |
|
|
|
st.write("### Topic Hierarchy") |
|
if st.session_state.get('hierarchy_fig') is not None: |
|
try: |
|
st.plotly_chart(st.session_state['hierarchy_fig']) |
|
except Exception: |
|
st.info("Topic hierarchy visualization is not available for the current clustering results.") |
|
if not (df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering): |
|
pass |
|
else: |
|
st.warning("Please select or upload a dataset and filter as needed.") |
|
|
|
|
|
|
|
|
|
|
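# Summarization tab: per-cluster and high-level LLM summaries, optionally with source references.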
|
with tab_summarization: |
|
st.header("Summarization") |
|
|
|
        with st.expander("ℹ️ How Summarization Works", expanded=False):
|
st.markdown(""" |
|
### Understanding Document Summarization |
|
|
|
Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works: |
|
|
|
1. **Summary Generation**: |
|
- Documents are processed using advanced language models |
|
- Key themes and important points are identified |
|
- Content is condensed while maintaining context |
|
- Both high-level and cluster-specific summaries are available |
|
|
|
2. **Reference System**: |
|
- Summaries can include references to source documents |
|
- References are shown as [ID] or as clickable links |
|
- Each statement can be traced back to its source |
|
- Helps maintain accountability and verification |
|
|
|
3. **Types of Summaries**: |
|
- **High-Level Summary**: Overview of all selected documents |
|
- Captures main themes across the entire selection |
|
- Ideal for quick understanding of large document sets |
|
- Shows relationships between different topics |
|
|
|
- **Cluster-Specific Summaries**: Focused on each cluster |
|
- More detailed for specific themes |
|
- Shows unique aspects of each cluster |
|
- Helps understand sub-topics in depth |
|
|
|
### How to Use Summaries |
|
|
|
1. **Configuration**: |
|
- Choose between all clusters or specific ones |
|
- Set temperature for creativity vs. consistency |
|
- Adjust max tokens for summary length |
|
- Enable/disable reference system |
|
|
|
2. **Reference Options**: |
|
- Select column for reference IDs |
|
- Add hyperlinks to references |
|
- Choose URL column for clickable links |
|
- References help track information sources |
|
|
|
3. **Practical Applications**: |
|
- Quick overview of large datasets |
|
- Detailed analysis of specific themes |
|
- Evidence-based reporting with references |
|
- Compare different document groups |
|
|
|
### Tips for Better Results |
|
|
|
- **Temperature Setting**: |
|
- Higher (0.7-1.0): More creative, varied summaries |
|
- Lower (0.1-0.3): More consistent, conservative summaries |
|
- Balance based on your needs for creativity vs. consistency |
|
|
|
- **Token Length**: |
|
- Longer limits: More detailed summaries |
|
- Shorter limits: More concise, focused summaries |
|
- Adjust based on document complexity |
|
|
|
- **Reference Usage**: |
|
- Enable references for traceability |
|
- Use hyperlinks for easy navigation |
|
- Choose meaningful reference columns |
|
- Helps validate summary accuracy |
|
|
|
### Best Practices |
|
|
|
1. **For General Overview**: |
|
- Use high-level summary |
|
- Keep temperature moderate (0.5-0.7) |
|
- Enable references for verification |
|
- Focus on broader themes |
|
|
|
2. **For Detailed Analysis**: |
|
- Use cluster-specific summaries |
|
- Adjust temperature based on need |
|
- Include references with hyperlinks |
|
- Look for patterns within clusters |
|
|
|
3. **For Reporting**: |
|
- Combine both summary types |
|
- Use references extensively |
|
- Balance detail and brevity |
|
- Ensure source traceability |
|
""") |
|
|
|
df_summ = None |
|
|
|
if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty: |
|
df_summ = st.session_state['clustered_data'] |
|
elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
|
df_summ = st.session_state['filtered_df'] |
|
else: |
|
st.warning("No data available for summarization. Please cluster first or have some filtered data.") |
|
|
|
if df_summ is not None and not df_summ.empty: |
|
text_columns = st.session_state.get('text_columns', []) |
|
if not text_columns: |
|
st.warning("No text columns selected. Please select columns for text embedding first.") |
|
else: |
|
if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state: |
|
st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.") |
|
else: |
|
topic_model = st.session_state['topic_model'] |
|
df_summ['text'] = df_summ.fillna("").astype(str)[text_columns].agg(' '.join, axis=1) |
|
|
|
|
|
topics = sorted(df_summ['Topic'].unique()) |
|
cluster_info = [] |
|
for t in topics: |
|
cluster_docs = df_summ[df_summ['Topic'] == t] |
|
count = len(cluster_docs) |
|
top_words = topic_model.get_topic(t) |
|
if top_words: |
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) |
|
else: |
|
top_keywords = "N/A" |
|
cluster_info.append((t, count, top_keywords)) |
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"]) |
|
|
|
|
|
if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns: |
|
summary_df = st.session_state['summary_df'] |
|
|
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])} |
|
|
|
cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster')) |
|
|
|
cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']] |
|
|
|
st.write("### Available Clusters:") |
|
st.dataframe( |
|
cluster_df, |
|
column_config={ |
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
|
"Top Keywords": st.column_config.TextColumn( |
|
"Top Keywords", |
|
help="Top 5 keywords that characterize this topic" |
|
) |
|
}, |
|
hide_index=True |
|
) |
|
|
|
|
|
st.subheader("Summarization Settings") |
|
|
|
summary_scope = st.radio( |
|
"Generate summaries for:", |
|
["All clusters", "Specific clusters"] |
|
) |
|
if summary_scope == "Specific clusters": |
|
|
|
if 'Cluster_Name' in cluster_df.columns: |
|
topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])] |
|
topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])} |
|
selected_topic_options = st.multiselect("Select clusters to summarize", topic_options) |
|
selected_topics = [topic_to_id[opt] for opt in selected_topic_options] |
|
else: |
|
selected_topics = st.multiselect("Select clusters to summarize", topics) |
|
else: |
|
selected_topics = topics |
|
|
|
|
|
default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries. |
|
You will be given text and an objective context. Please produce a clear, cohesive, |
|
and thematically relevant summary. |
|
Focus on key points, insights, or patterns that emerge from the text.""" |
|
|
|
if 'system_prompt' not in st.session_state: |
|
st.session_state['system_prompt'] = default_system_prompt |
|
|
|
with st.expander("π§ Advanced Settings", expanded=False): |
|
st.markdown(""" |
|
### System Prompt Configuration |
|
|
|
The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs: |
|
- Be specific about the style and focus you want |
|
- Add domain-specific context if needed |
|
- Include any special formatting requirements |
|
""") |
|
|
|
system_prompt = st.text_area( |
|
"Customize System Prompt", |
|
value=st.session_state['system_prompt'], |
|
height=150, |
|
help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus." |
|
) |
|
|
|
if st.button("Reset to Default"): |
|
system_prompt = default_system_prompt |
|
st.session_state['system_prompt'] = default_system_prompt |
|
|
|
st.markdown("### Generation Parameters") |
|
temperature = st.slider( |
|
"Temperature", |
|
0.0, 1.0, 0.7, |
|
help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent." |
|
) |
|
max_tokens = st.slider( |
|
"Max Tokens", |
|
100, 3000, 1000, |
|
help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate." |
|
) |
|
|
|
st.session_state['system_prompt'] = system_prompt |
|
|
|
st.write("### Enhanced Summary References") |
|
st.write("Select columns for references (optional).") |
|
all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']] |
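# Columns derived by the app (text, Topic, similarity_score) are excluded from the reference options.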
|
|
|
|
|
if 'reference_id_column' not in st.session_state: |
|
st.session_state.reference_id_column = all_cols[0] if all_cols else None |
|
|
|
url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None) |
|
if 'url_column' not in st.session_state: |
|
st.session_state.url_column = url_guess |
|
|
|
enable_references = st.checkbox( |
|
"Enable references in summaries", |
|
value=True, |
|
help="Add source references to the final summary text." |
|
) |
|
reference_id_column = st.selectbox( |
|
"Select column to use as reference ID:", |
|
all_cols, |
|
index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0 |
|
) |
|
add_hyperlinks = st.checkbox( |
|
"Add hyperlinks to references", |
|
value=True, |
|
help="If the reference column has a matching URL, make it clickable." |
|
) |
|
url_column = None |
|
if add_hyperlinks: |
|
url_column = st.selectbox( |
|
"Select column containing URLs:", |
|
all_cols, |
|
index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0 |
|
) |
|
|
|
|
|
if st.button("Generate Summaries"): |
|
openai_api_key = os.environ.get('OPENAI_API_KEY') |
|
if not openai_api_key: |
|
st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") |
|
else: |
|
|
|
st.session_state['_summarization_button_clicked'] = True |
|
|
|
llm = ChatOpenAI( |
|
api_key=openai_api_key, |
|
model_name='gpt-4o-mini', |
|
temperature=temperature, |
|
max_tokens=max_tokens |
|
) |
|
|
|
|
|
if selected_topics: |
|
df_scope = df_summ[df_summ['Topic'].isin(selected_topics)] |
|
else: |
|
st.warning("No topics selected for summarization.") |
|
df_scope = pd.DataFrame() |
|
|
|
if df_scope.empty: |
|
st.warning("No documents match the selected topics for summarization.") |
|
else: |
|
all_texts = df_scope['text'].tolist() |
|
combined_text = " ".join(all_texts) |
|
if not combined_text.strip(): |
|
st.warning("No text data available for summarization.") |
|
else: |
|
|
|
local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt']) |
|
local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
|
local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message]) |
|
|
|
|
|
|
|
unique_selected_topics = df_scope['Topic'].unique() |
|
if len(unique_selected_topics) > 1: |
|
st.write("### Summaries per Selected Cluster") |
|
|
|
|
|
with st.spinner("Generating cluster summaries in parallel..."): |
|
summaries = process_summaries_in_parallel( |
|
df_scope=df_scope, |
|
unique_selected_topics=unique_selected_topics, |
|
llm=llm, |
|
chat_prompt=local_chat_prompt, |
|
enable_references=enable_references, |
|
reference_id_column=reference_id_column, |
|
url_column=url_column if add_hyperlinks else None, |
|
max_workers=min(16, len(unique_selected_topics)) |
|
) |
|
|
|
if summaries: |
|
summary_df = pd.DataFrame(summaries) |
|
|
|
st.session_state['summary_df'] = summary_df |
|
|
|
st.session_state['has_references'] = enable_references |
|
st.session_state['reference_id_column'] = reference_id_column |
|
st.session_state['url_column'] = url_column if add_hyperlinks else None |
|
|
|
|
|
if 'Cluster_Name' in summary_df.columns: |
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])} |
|
cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster')) |
|
cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']] |
|
|
|
|
|
st.write("### Updated Topic Overview:") |
|
st.dataframe( |
|
cluster_df, |
|
column_config={ |
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
|
"Top Keywords": st.column_config.TextColumn( |
|
"Top Keywords", |
|
help="Top 5 keywords that characterize this topic" |
|
) |
|
}, |
|
hide_index=True |
|
) |
|
|
|
|
|
with st.spinner("Generating high-level summary from cluster summaries..."): |
|
|
|
formatted_summaries = [] |
|
total_tokens = 0 |
|
MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) |
|
summary_batches = [] |
|
current_batch = [] |
|
current_batch_tokens = 0 |
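# Greedily pack formatted cluster summaries into batches that stay within the safe token budget (75% of the context window).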
|
|
|
for _, row in summary_df.iterrows(): |
|
summary_text = row.get('Enhanced_Summary', row['Summary']) |
|
formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}" |
|
summary_tokens = len(tokenizer(formatted_summary)["input_ids"]) |
|
|
|
|
|
if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS: |
|
if current_batch: |
|
summary_batches.append(current_batch) |
|
current_batch = [] |
|
current_batch_tokens = 0 |
|
|
|
current_batch.append(formatted_summary) |
|
current_batch_tokens += summary_tokens |
|
|
|
|
|
if current_batch: |
|
summary_batches.append(current_batch) |
|
|
|
|
|
batch_overviews = [] |
|
with st.spinner("Generating batch summaries..."): |
|
for i, batch in enumerate(summary_batches, 1): |
|
st.write(f"Processing batch {i} of {len(summary_batches)}...") |
|
|
|
batch_text = "\n\n".join(batch) |
|
batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>. |
|
|
|
Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT: |
|
1. Preserve all hyperlinked references exactly as they appear in the input summaries |
|
2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries |
|
3. Keep the markdown formatting for better readability |
|
4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters |
|
|
|
Here are the cluster summaries to synthesize: |
|
|
|
{batch_text}""" |
|
|
|
|
|
high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt']) |
|
high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
|
high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message]) |
|
high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt) |
|
batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip() |
|
batch_overviews.append(batch_overview) |
|
|
|
|
|
with st.spinner("Generating final combined summary..."): |
|
combined_overviews = "\n\n### Part ".join([f"{i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
|
final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents. |
|
|
|
Please create a final comprehensive synthesis that: |
|
1. Integrates the key themes and findings from all parts |
|
2. Preserves all hyperlinked references exactly as they appear |
|
3. Maintains the HTML anchor tags (<a href="...">) intact |
|
4. Keeps the markdown formatting for better readability |
|
5. Creates a coherent narrative across all parts |
|
6. Highlights any themes that span multiple parts |
|
|
|
Here are the overviews to synthesize: |
|
|
|
|
|
|
{combined_overviews}""" |
|
|
|
|
|
final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"]) |
|
if final_prompt_tokens > MAX_SAFE_TOKENS: |
|
st.error(f"β Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.") |
|
high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
|
else: |
|
|
|
high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt) |
|
high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip() |
|
|
|
|
|
st.session_state['high_level_summary'] = high_level_summary |
|
st.session_state['enhanced_summary'] = high_level_summary |
|
|
|
|
|
st.session_state['summarization_completed'] = True |
|
|
|
|
|
st.write("### High-Level Summary:") |
|
st.markdown(high_level_summary, unsafe_allow_html=True) |
|
|
|
|
|
st.write("### Cluster Summaries:") |
|
if enable_references and 'Enhanced_Summary' in summary_df.columns: |
|
for idx, row in summary_df.iterrows(): |
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
|
st.write("---") |
|
with st.expander("View original summaries in table format"): |
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] |
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
st.dataframe(display_df, hide_index=True) |
|
else: |
|
st.write("### Summaries per Cluster:") |
|
if 'Cluster_Name' in summary_df.columns: |
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] |
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
st.dataframe(display_df, hide_index=True) |
|
else: |
|
st.dataframe(summary_df, hide_index=True) |
|
|
|
|
|
if 'Cluster_Name' in summary_df.columns:
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] |
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
else: |
|
dl_df = summary_df |
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
|
b64 = base64.b64encode(csv_bytes).decode() |
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
|
st.markdown(href, unsafe_allow_html=True) |
|
|
|
|
|
if st.session_state.get('summarization_completed', False): |
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
|
if 'high_level_summary' in st.session_state: |
|
st.write("### High-Level Summary:") |
|
st.markdown(st.session_state['enhanced_summary'] if st.session_state.get('enhanced_summary') else st.session_state['high_level_summary'], unsafe_allow_html=True) |
|
|
|
st.write("### Cluster Summaries:") |
|
summary_df = st.session_state['summary_df'] |
|
if 'Enhanced_Summary' in summary_df.columns: |
|
for idx, row in summary_df.iterrows(): |
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
|
st.write("---") |
|
with st.expander("View original summaries in table format"): |
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] |
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
st.dataframe(display_df, hide_index=True) |
|
else: |
|
st.dataframe(summary_df, hide_index=True) |
|
|
|
|
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df |
|
if 'Cluster_Name' in dl_df.columns: |
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
|
b64 = base64.b64encode(csv_bytes).decode() |
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
|
st.markdown(href, unsafe_allow_html=True) |
|
else: |
|
st.warning("No data available for summarization.") |
|
|
|
|
|
if not st.session_state.get('_summarization_button_clicked', False): |
|
if 'high_level_summary' in st.session_state: |
|
st.write("### Existing High-Level Summary:") |
|
if st.session_state.get('enhanced_summary'): |
|
st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True) |
|
with st.expander("View original summary (without references)"): |
|
st.write(st.session_state['high_level_summary']) |
|
else: |
|
st.write(st.session_state['high_level_summary']) |
|
|
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
|
st.write("### Existing Cluster Summaries:") |
|
summary_df = st.session_state['summary_df'] |
|
if 'Enhanced_Summary' in summary_df.columns: |
|
for idx, row in summary_df.iterrows(): |
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**") |
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True) |
|
st.write("---") |
|
with st.expander("View original summaries in table format"): |
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] |
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
st.dataframe(display_df, hide_index=True) |
|
else: |
|
st.dataframe(summary_df, hide_index=True) |
|
|
|
|
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df |
|
if 'Cluster_Name' in dl_df.columns: |
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary'] |
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8') |
|
b64 = base64.b64encode(csv_bytes).decode() |
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>' |
|
st.markdown(href, unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
|
|
|
with tab_chat: |
|
st.header("Chat with Your Data") |
|
|
|
|
|
with st.expander("βΉοΈ How Chat Works", expanded=False): |
|
st.markdown(""" |
|
### Understanding Chat with Your Data |
|
|
|
The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works: |
|
|
|
1. **Data Selection**: |
|
- Choose which dataset to chat about (filtered, clustered, or search results) |
|
- Optionally focus on specific clusters if clustering was performed |
|
- System automatically includes relevant context from your selection |
|
|
|
2. **Context Window**: |
|
- Shows how much of the model's context window is being used
|
- Helps you understand if you need to filter data further |
|
- Displays token usage statistics |
|
|
|
3. **Chat Features**: |
|
- Ask questions about your data |
|
- Get insights and analysis |
|
- Reference specific documents or clusters |
|
- Download chat context for transparency |
|
|
|
### Best Practices |
|
|
|
1. **Data Selection**: |
|
- Start with filtered or clustered data for more focused conversations |
|
- Select specific clusters if you want to dive deep into a topic |
|
- Consider the context window usage when selecting data |
|
|
|
2. **Asking Questions**: |
|
- Be specific in your questions |
|
- Ask about patterns, trends, or insights |
|
- Reference clusters or documents by their IDs |
|
- Build on previous questions for deeper analysis |
|
|
|
3. **Managing Context**: |
|
- Monitor the context window usage |
|
- Filter data further if context is too full |
|
- Download chat context for documentation |
|
- Clear chat history to start fresh |
|
|
|
### Tips for Better Results |
|
|
|
- **Question Types**: |
|
- "What are the main themes in cluster 3?" |
|
- "Compare the findings between clusters 1 and 2" |
|
- "Summarize the methodology used across these documents" |
|
- "What are the common outcomes reported?" |
|
|
|
- **Follow-up Questions**: |
|
- Build on previous answers |
|
- Ask for clarification |
|
- Request specific examples |
|
- Explore relationships between findings |
|
""") |
|
|
|
|
|
def get_available_data_sources(): |
|
sources = [] |
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty: |
|
sources.append("Filtered Dataset") |
|
if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty: |
|
sources.append("Clustered Data") |
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
|
sources.append("Search Results") |
|
if ('high_level_summary' in st.session_state or |
|
('summary_df' in st.session_state and not st.session_state['summary_df'].empty)): |
|
sources.append("Summarized Data") |
|
return sources |
|
|
|
|
|
available_sources = get_available_data_sources() |
|
|
|
if not available_sources: |
|
st.warning("No data available for chat. Please filter, cluster, search, or summarize first.") |
|
st.stop() |
|
|
|
|
|
if 'chat_data_source' not in st.session_state: |
|
st.session_state.chat_data_source = available_sources[0] |
|
elif st.session_state.chat_data_source not in available_sources: |
|
st.session_state.chat_data_source = available_sources[0] |
|
|
|
|
|
data_source = st.radio( |
|
"Select data to chat about:", |
|
available_sources, |
|
index=available_sources.index(st.session_state.chat_data_source), |
|
help="Choose which dataset you want to analyze in the chat." |
|
) |
|
|
|
|
|
if data_source != st.session_state.chat_data_source: |
|
st.session_state.chat_data_source = data_source |
|
|
|
if 'chat_selected_cluster' in st.session_state: |
|
del st.session_state.chat_selected_cluster |
|
|
|
|
|
df_chat = None |
|
if data_source == "Filtered Dataset": |
|
df_chat = st.session_state['filtered_df'] |
|
elif data_source == "Clustered Data": |
|
df_chat = st.session_state['clustered_data'] |
|
elif data_source == "Search Results": |
|
df_chat = st.session_state['search_results'] |
|
elif data_source == "Summarized Data": |
|
|
|
summary_rows = [] |
|
|
|
|
|
if 'high_level_summary' in st.session_state: |
|
summary_rows.append({ |
|
'Summary_Type': 'High-Level Summary', |
|
'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary']) |
|
}) |
|
|
|
|
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty: |
|
summary_df = st.session_state['summary_df'] |
|
for _, row in summary_df.iterrows(): |
|
summary_rows.append({ |
|
'Summary_Type': f"Cluster {row['Topic']} Summary", |
|
'Content': row.get('Enhanced_Summary', row['Summary']) |
|
}) |
|
|
|
if summary_rows: |
|
df_chat = pd.DataFrame(summary_rows) |
|
|
|
if df_chat is not None and not df_chat.empty: |
|
|
|
selected_cluster = None |
|
if data_source != "Summarized Data" and 'Topic' in df_chat.columns: |
|
cluster_option = st.radio( |
|
"Choose cluster scope:", |
|
["All Clusters", "Specific Cluster"] |
|
) |
|
if cluster_option == "Specific Cluster": |
|
unique_topics = sorted(df_chat['Topic'].unique()) |
|
|
|
if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns: |
|
summary_df = st.session_state['summary_df'] |
|
|
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])} |
|
|
|
topic_options = [ |
|
(t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}") |
|
for t in unique_topics |
|
] |
|
selected_cluster = st.selectbox( |
|
"Select cluster to focus on:", |
|
[t[0] for t in topic_options], |
|
format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x) |
|
) |
|
else: |
|
selected_cluster = st.selectbox( |
|
"Select cluster to focus on:", |
|
unique_topics, |
|
format_func=lambda x: f"Cluster {x}" |
|
) |
|
if selected_cluster is not None: |
|
df_chat = df_chat[df_chat['Topic'] == selected_cluster] |
|
st.session_state.chat_selected_cluster = selected_cluster |
|
elif 'chat_selected_cluster' in st.session_state: |
|
del st.session_state.chat_selected_cluster |
|
|
|
|
|
text_columns = st.session_state.get('text_columns', []) |
|
if not text_columns and data_source != "Summarized Data": |
|
st.warning("No text columns selected. Please select text columns to enable chat functionality.") |
|
st.stop() |
|
|
|
|
|
MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95) |
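# Reserve roughly 5% of the context window for the user's question and the model's reply.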
|
|
|
|
|
system_msg = { |
|
"role": "system", |
|
"content": """You are a specialized assistant analyzing data from a research database. |
|
Your role is to: |
|
1. Provide clear, concise answers based on the data provided |
|
2. Highlight relevant information from specific results when answering |
|
3. When referencing specific results, use their row index or ID if available |
|
4. Clearly state if information is not available in the results |
|
5. Maintain a professional and analytical tone |
|
6. Format your responses using Markdown: |
|
- Use **bold** for emphasis |
|
- Use bullet points and numbered lists for structured information |
|
- Create tables using Markdown syntax when presenting structured data |
|
- Use backticks for code or technical terms |
|
- Include hyperlinks when referencing external sources |
|
- Use headings (###) to organize long responses |
|
|
|
The data is provided in a structured format where:""" + (""" |
|
- Each result contains multiple fields |
|
- Text content is primarily in the following columns: """ + ", ".join(text_columns) + """ |
|
- Additional metadata and fields are available for reference |
|
- If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """ |
|
- The data consists of AI-generated summaries of the documents |
|
- Each summary may contain references to source documents in markdown format |
|
- References are shown as [ID] or as clickable hyperlinks |
|
- Summaries may be high-level (covering all documents) or cluster-specific""") + """ |
|
""" |
|
} |
|
|
|
|
|
system_tokens = len(tokenizer(system_msg["content"])["input_ids"]) |
|
remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens |
|
|
|
|
|
data_text = "Available Data:\n" |
|
included_rows = 0 |
|
total_rows = len(df_chat) |
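# Serialize rows into the prompt until the token budget is exhausted; remaining rows are dropped and reported in the usage statistics below.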
|
|
|
if data_source == "Summarized Data": |
|
|
|
for idx, row in df_chat.iterrows(): |
|
row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n" |
|
row_tokens = len(tokenizer(row_text)["input_ids"]) |
|
|
|
if remaining_tokens - row_tokens > 0: |
|
data_text += row_text |
|
remaining_tokens -= row_tokens |
|
included_rows += 1 |
|
else: |
|
break |
|
else: |
|
|
|
for idx, row in df_chat.iterrows(): |
|
row_text = f"\nItem {idx}:\n" |
|
for col in df_chat.columns: |
|
if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score': |
|
row_text += f"{col}: {row[col]}\n" |
|
|
|
row_tokens = len(tokenizer(row_text)["input_ids"]) |
|
if remaining_tokens - row_tokens > 0: |
|
data_text += row_text |
|
remaining_tokens -= row_tokens |
|
included_rows += 1 |
|
else: |
|
break |
|
|
|
|
|
data_tokens = len(tokenizer(data_text)["input_ids"]) |
|
total_tokens = system_tokens + data_tokens |
|
context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100 |
|
|
|
|
|
st.subheader("Context Window Usage") |
|
st.write(f"System Message: {system_tokens:,} tokens") |
|
st.write(f"Data Context: {data_tokens:,} tokens") |
|
st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)") |
|
st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)") |
|
|
|
if context_usage_percent > 90: |
|
st.warning("β οΈ High context usage! Consider reducing the number of results or filtering further.") |
|
elif context_usage_percent > 75: |
|
st.info("βΉοΈ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.") |
|
|
|
|
|
chat_context = f"""System Message: |
|
{system_msg['content']} |
|
|
|
{data_text}""" |
|
st.download_button( |
|
label="π₯ Download Chat Context", |
|
data=chat_context, |
|
file_name="chat_context.txt", |
|
mime="text/plain", |
|
help="Download the exact context that the chatbot receives" |
|
) |
|
|
|
|
|
col_chat1, col_chat2 = st.columns([3, 1]) |
|
with col_chat1: |
|
user_input = st.text_area("Ask a question about your data:", key="chat_input") |
|
with col_chat2: |
|
if st.button("Clear Chat History"): |
|
st.session_state.chat_history = [] |
|
st.rerun() |
|
|
|
|
|
current_tab = tabs_titles.index("Chat") |
|
|
|
if st.button("Send", key="send_button"): |
|
if user_input: |
|
|
|
st.session_state.active_tab_index = current_tab |
|
|
|
with st.spinner("Processing your question..."): |
|
|
|
st.session_state.chat_history.append({"role": "user", "content": user_input}) |
|
|
|
|
|
messages = [system_msg] |
|
messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"}) |
|
|
|
|
|
response = get_chat_response(messages) |
|
|
|
if response: |
|
st.session_state.chat_history.append({"role": "assistant", "content": response}) |
|
|
|
|
|
st.subheader("Chat History") |
|
for message in st.session_state.chat_history: |
|
if message["role"] == "user": |
|
st.write("**You:**", message["content"]) |
|
else: |
|
st.write("**Assistant:**") |
|
st.markdown(message["content"], unsafe_allow_html=True) |
|
st.write("---") |
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
st.header("Automatic Mode") |
|
|
|
|
|
if 'df' not in st.session_state: |
|
default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx') |
|
df = load_default_dataset(default_dataset_path) |
|
if df is not None: |
|
st.session_state['df'] = df.copy() |
|
st.session_state['using_default_dataset'] = True |
|
st.session_state['filtered_df'] = df.copy() |
|
|
|
|
|
if 'text_columns' not in st.session_state or not st.session_state['text_columns']: |
|
default_text_cols = [] |
|
if 'Title' in df.columns and 'Description' in df.columns: |
|
default_text_cols = ['Title', 'Description'] |
|
st.session_state['text_columns'] = default_text_cols |
|
|
|
|
|
|
|
query = st.text_input("Write your query here:") |
|
|
|
|
|
|
|
|
|
if st.button("SNAP!"): |
|
if query.strip(): |
|
|
|
st.write("### Step 1: Semantic Search") |
|
with st.spinner("Performing Semantic Search..."): |
|
text_columns = st.session_state.get('text_columns', []) |
|
if text_columns: |
|
df_full = st.session_state['df'] |
|
embeddings, _ = load_or_compute_embeddings( |
|
df_full, |
|
st.session_state.get('using_default_dataset', False), |
|
st.session_state.get('uploaded_file_name'), |
|
text_columns |
|
) |
|
|
|
if embeddings is not None: |
|
model = get_embedding_model() |
|
df_filtered = st.session_state['filtered_df'].fillna("") |
|
search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist() |
|
|
|
subset_indices = df_filtered.index |
|
subset_embeddings = embeddings[subset_indices] |
|
|
|
query_embedding = model.encode([query], device=device) |
|
similarities = cosine_similarity(query_embedding, subset_embeddings)[0] |
|
|
|
similarity_threshold = 0.35 |
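# Only documents scoring above this cosine-similarity cutoff against the query are kept.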
|
above_threshold_indices = np.where(similarities > similarity_threshold)[0] |
|
|
|
if len(above_threshold_indices) > 0: |
|
selected_indices = subset_indices[above_threshold_indices] |
|
results = df_filtered.loc[selected_indices].copy() |
|
results['similarity_score'] = similarities[above_threshold_indices] |
|
results.sort_values(by='similarity_score', ascending=False, inplace=True) |
|
st.session_state['search_results'] = results.copy() |
|
st.write(f"Found {len(results)} relevant documents") |
|
else: |
|
st.warning("No results found above the similarity threshold.") |
|
st.stop() |
|
|
|
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty: |
|
st.write("### Step 2: Clustering") |
|
with st.spinner("Performing clustering..."): |
|
df_to_cluster = st.session_state['search_results'].copy() |
|
dfc = df_to_cluster.copy().fillna("") |
|
dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1) |
|
|
|
|
|
selected_indices = dfc.index |
|
embeddings_clustering = embeddings[selected_indices] |
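# Remove English stopwords from the texts so BERTopic's topic keywords are more informative; the precomputed embeddings are reused unchanged.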
|
|
|
|
|
stop_words = set(stopwords.words('english')) |
|
texts_cleaned = [] |
|
for text in dfc['text'].tolist(): |
|
try: |
|
word_tokens = word_tokenize(text) |
|
filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words]) |
|
texts_cleaned.append(filtered_text) |
|
except Exception as e: |
|
texts_cleaned.append(text) |
|
|
|
min_cluster_size = 5 |
|
|
|
try: |
|
|
|
if torch.is_tensor(embeddings_clustering): |
|
embeddings_for_clustering = embeddings_clustering.cpu().numpy() |
|
else: |
|
embeddings_for_clustering = embeddings_clustering |
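# HDBSCAN (via BERTopic) clusters the precomputed embeddings; topic -1 collects outlier documents.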
|
|
|
|
|
hdbscan_model = HDBSCAN( |
|
min_cluster_size=min_cluster_size, |
|
metric='euclidean', |
|
cluster_selection_method='eom' |
|
) |
|
|
|
|
|
topic_model = BERTopic( |
|
embedding_model=get_embedding_model(), |
|
hdbscan_model=hdbscan_model |
|
) |
|
|
|
|
|
topics, probs = topic_model.fit_transform( |
|
texts_cleaned, |
|
embeddings=embeddings_for_clustering |
|
) |
|
|
|
|
|
dfc['Topic'] = topics |
|
st.session_state['topic_model'] = topic_model |
|
st.session_state['clustered_data'] = dfc.copy() |
|
st.session_state['clustering_completed'] = True |
|
|
|
|
|
unique_topics = sorted(list(set(topics))) |
|
num_clusters = len([t for t in unique_topics if t != -1]) |
|
noise_docs = len([t for t in topics if t == -1]) |
|
clustered_docs = len(topics) - noise_docs |
|
|
|
st.write(f"Found {num_clusters} distinct clusters") |
|
|
|
|
|
|
|
|
|
|
|
cluster_info = [] |
|
for t in unique_topics: |
|
if t != -1: |
|
cluster_docs = dfc[dfc['Topic'] == t] |
|
count = len(cluster_docs) |
|
top_words = topic_model.get_topic(t) |
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A" |
|
cluster_info.append((t, count, top_keywords)) |
|
|
|
if cluster_info: |
|
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"]) |
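# cluster_df gives a quick per-topic overview (document count, top keywords); the named overview table is rendered after summarization below.
# The BERTopic visualizations further down are cached in session state for later display without refitting.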
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics() |
|
except Exception: |
|
st.session_state['intertopic_distance_fig'] = None |
|
|
|
try: |
|
st.session_state['topic_document_fig'] = topic_model.visualize_documents( |
|
texts_cleaned, |
|
embeddings=embeddings_for_clustering |
|
) |
|
except Exception: |
|
st.session_state['topic_document_fig'] = None |
|
|
|
try: |
|
hierarchy = topic_model.hierarchical_topics(texts_cleaned) |
|
st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame() |
|
st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy() |
|
except Exception: |
|
st.session_state['hierarchy'] = pd.DataFrame() |
|
st.session_state['hierarchy_fig'] = None |
|
|
|
except Exception as e: |
|
st.error(f"An error occurred during clustering: {str(e)}") |
|
st.stop() |
|
|
|
|
|
if st.session_state.get('clustering_completed', False): |
|
st.write("### Step 3: Summarization") |
|
|
|
|
|
openai_api_key = os.environ.get('OPENAI_API_KEY') |
|
if not openai_api_key: |
|
st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") |
|
st.stop() |
|
|
|
llm = ChatOpenAI( |
|
api_key=openai_api_key, |
|
model_name='gpt-4o-mini', |
|
temperature=0.7, |
|
max_tokens=1000 |
|
) |
|
|
|
df_scope = st.session_state['clustered_data'] |
|
unique_selected_topics = df_scope['Topic'].unique() |
|
|
|
|
|
with st.spinner("Generating summaries..."): |
|
local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries. |
|
You will be given text and an objective context. Please produce a clear, cohesive, |
|
and thematically relevant summary. |
|
Focus on key points, insights, or patterns that emerge from the text.""") |
|
local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}") |
|
local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message]) |
|
|
|
|
|
url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None) |
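# The first column serves as the reference ID; any URL/link/PDF-like column provides hyperlinks for the enhanced summaries.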
|
|
|
summaries = process_summaries_in_parallel( |
|
df_scope=df_scope, |
|
unique_selected_topics=unique_selected_topics, |
|
llm=llm, |
|
chat_prompt=local_chat_prompt, |
|
enable_references=True, |
|
reference_id_column=df_scope.columns[0], |
|
url_column=url_column, |
|
max_workers=min(16, len(unique_selected_topics)) |
|
) |
|
|
|
if summaries: |
|
summary_df = pd.DataFrame(summaries) |
|
st.session_state['summary_df'] = summary_df |
|
|
|
|
|
if 'Cluster_Name' in summary_df.columns: |
|
st.write("### Updated Topic Overview:") |
|
cluster_info = [] |
|
for t in unique_selected_topics: |
|
cluster_docs = df_scope[df_scope['Topic'] == t] |
|
count = len(cluster_docs) |
|
top_words = topic_model.get_topic(t) |
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A" |
|
cluster_name = summary_df.set_index('Topic')['Cluster_Name'].get(t, 'Unnamed Cluster')
|
cluster_info.append((t, cluster_name, count, top_keywords)) |
|
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"]) |
|
st.dataframe( |
|
cluster_df, |
|
column_config={ |
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"), |
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"), |
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"), |
|
"Top Keywords": st.column_config.TextColumn( |
|
"Top Keywords", |
|
help="Top 5 keywords that characterize this topic" |
|
) |
|
}, |
|
hide_index=True |
|
) |
|
|
|
|
|
with st.spinner("Generating high-level summary..."): |
|
formatted_summaries = [] |
|
summary_batches = [] |
|
current_batch = [] |
|
current_batch_tokens = 0 |
|
MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) |
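# Same greedy batching as in the manual workflow: pack cluster summaries under the safe token budget before synthesizing.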
|
|
|
for _, row in summary_df.iterrows(): |
|
summary_text = row.get('Enhanced_Summary', row['Summary']) |
|
formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}" |
|
summary_tokens = len(tokenizer(formatted_summary)["input_ids"]) |
|
|
|
if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS: |
|
if current_batch: |
|
summary_batches.append(current_batch) |
|
current_batch = [] |
|
current_batch_tokens = 0 |
|
|
|
current_batch.append(formatted_summary) |
|
current_batch_tokens += summary_tokens |
|
|
|
if current_batch: |
|
summary_batches.append(current_batch) |
|
|
|
|
|
batch_overviews = [] |
|
for i, batch in enumerate(summary_batches, 1): |
|
st.write(f"Processing summary batch {i} of {len(summary_batches)}...") |
|
batch_text = "\n\n".join(batch) |
|
batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>. |
|
|
|
Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT: |
|
1. Preserve all hyperlinked references exactly as they appear in the input summaries |
|
2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries |
|
3. Keep the markdown formatting for better readability |
|
4. Create clear sections with headings for different themes |
|
5. Use bullet points or numbered lists where appropriate |
|
6. Focus on synthesizing the main themes and findings |
|
|
|
Here are the cluster summaries to synthesize: |
|
|
|
{batch_text}""" |
|
|
|
high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt) |
|
batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip() |
|
batch_overviews.append(batch_overview) |
|
|
|
|
|
if len(batch_overviews) > 1: |
|
st.write("Generating final synthesis...") |
|
combined_overviews = "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
|
final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents. |
|
|
|
Please create a final comprehensive synthesis that: |
|
1. Integrates the key themes and findings from all parts into a cohesive narrative |
|
2. Preserves all hyperlinked references exactly as they appear |
|
3. Maintains the HTML anchor tags (<a href="...">) intact |
|
4. Uses clear section headings and structured formatting |
|
5. Highlights cross-cutting themes and relationships between different aspects |
|
6. Provides a clear introduction and conclusion |
|
|
|
Here are the overviews to synthesize: |
|
|
|
|
|
|
{combined_overviews}""" |
|
|
|
final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"]) |
|
if final_prompt_tokens > MAX_SAFE_TOKENS: |
|
|
|
high_level_summary = "# Comprehensive Overview\n\n" + "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)]) |
|
else: |
|
high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt) |
|
high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip() |
|
else: |
|
|
|
high_level_summary = batch_overviews[0] |
|
|
|
st.session_state['high_level_summary'] = high_level_summary |
|
st.session_state['enhanced_summary'] = high_level_summary |
|
|
|
|
|
st.write("### High-Level Summary:") |
|
with st.expander("High-Level Summary", expanded=True): |
|
st.markdown(high_level_summary, unsafe_allow_html=True) |
|
|
|
st.write("### Cluster Summaries:") |
|
for idx, row in summary_df.iterrows(): |
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster') |
|
with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False): |
|
st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True) |
|
st.markdown("##### About this tool") |
|
with st.expander("Click to expand/collapse", expanded=True): |
|
st.markdown(""" |
|
This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is currently working on.
|
|
|
**Tips:** |
|
- **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`). |
|
- Avoid writing full questions: **this is not a chatbot**.
|
- Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`). |
|
- Focus on **concepts or themes**, not single words like `"climate"` or `"yield"` alone.
|
- Example good queries: |
|
- `"climate adaptation smallholder farming"` |
|
- `"digital agriculture innovations"` |
|
- `"nutrition-sensitive value chains"` |
|
|
|
**Example use case**: |
|
You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**. |
|
A good search phrase would be: |
|
🔍 `"poverty reduction maize Africa"`
|
This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*. |
|
""") |