|
|
import os |
|
|
import time |
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import re |
|
|
import string |
|
|
import json |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
import plotly.io as pio |
|
|
from pptx import Presentation |
|
|
from pptx.util import Inches, Pt |
|
|
|
|
|
|
|
|
from gliner import GLiNER |
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
|
from sklearn.decomposition import LatentDirichletAllocation |
|
|
|
|
|
|
|
|
# NOTE(review): `gliner` is imported above this line, and Hugging Face
# libraries typically resolve HF_HOME when they are first imported, so this
# assignment may come too late to change where the model is cached. Consider
# moving it above the third-party imports -- TODO confirm.
os.environ["HF_HOME"] = "/tmp"

# Highlight (CSS background) colour for every entity label. The key order also
# defines the label list handed to the model via `labels` below.
entity_color_map = {
    "person": "#10b981",
    "country": "#3b82f6",
    "city": "#4ade80",
    "organization": "#f59e0b",
    "date": "#8b5cf6",
    "time": "#ec4899",
    "cardinal": "#06b6d4",
    "money": "#f43f5e",
    "position": "#a855f7",
}

# Entity labels requested from GLiNER, in colour-map order.
labels = [*entity_color_map]

# Coarse grouping of labels; used as the top level of the treemaps.
category_mapping = {
    "People": ["person", "organization", "position"],
    "Locations": ["country", "city"],
    "Time": ["date", "time"],
    "Numbers": ["money", "cardinal"],
}

# Inverse lookup: entity label -> category name.
reverse_category_mapping = {
    member: category
    for category, members in category_mapping.items()
    for member in members
}
|
|
|
|
|
|
|
|
|
|
|
def remove_trailing_punctuation(text_string):
    """Return *text_string* without its run of trailing ``string.punctuation``.

    Only characters at the very end are removed; punctuation inside the
    string (e.g. the dots in "U.S.A.") is kept up to the last letter.
    """
    cut = len(text_string)
    while cut > 0 and text_string[cut - 1] in string.punctuation:
        cut -= 1
    return text_string[:cut]
|
|
|
|
|
def highlight_entities(text, df_entities, color_map=None):
    """Render *text* as an HTML block with each detected entity highlighted.

    Args:
        text: The exact string the entities were extracted from; the
            ``start``/``end`` offsets in *df_entities* index into it.
        df_entities: DataFrame with at least ``start``, ``end`` and ``label``
            columns, one row per entity. May be empty or ``None``.
        color_map: Optional ``label -> CSS colour`` mapping. Defaults to the
            module-level ``entity_color_map``. Unknown labels are drawn black.

    Returns:
        An HTML string (a styled ``<div>``). All text taken from *text* is
        HTML-escaped, so the result is safe to render with
        ``unsafe_allow_html=True`` even when the input contains markup.
    """
    import html

    colors = entity_color_map if color_map is None else color_map
    text = text or ""

    if df_entities is None or df_entities.empty:
        entities = []
    else:
        entities = df_entities.sort_values(by=["start", "end"]).to_dict("records")

    pieces = []
    cursor = 0
    for entity in entities:
        start, end = int(entity["start"]), int(entity["end"])
        # Skip spans that overlap one already emitted or fall outside the text;
        # splicing them in would corrupt the surrounding markup.
        if start < cursor or start >= end or end > len(text):
            continue
        pieces.append(html.escape(text[cursor:start]))
        color = colors.get(entity["label"], "#000000")
        # Highlight the source slice rather than entity['text']: the caller may
        # have stripped trailing punctuation from entity['text'], and
        # substituting it would silently drop those characters from the output.
        pieces.append(
            f'<span style="background-color: {color}; color: white; padding: 2px 4px; '
            f'border-radius: 3px; font-weight: bold;">{html.escape(text[start:end])}</span>'
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))

    highlighted_text = "".join(pieces)
    return f'<div class="highlighted-text" style="border: 1px solid #ddd; padding: 15px; border-radius: 8px; background-color: #ffffff; line-height: 2; white-space: pre-wrap;">{highlighted_text}</div>'
|
|
|
|
|
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
    """Fit an LDA topic model over the distinct entity strings.

    Each unique, non-null value of ``df_entities['text']`` is treated as one
    short document. The documents are TF-IDF vectorised (English stop words
    removed, 1- to 3-grams) and decomposed into *num_topics* LDA topics.

    Args:
        df_entities: DataFrame with a ``text`` column. May be empty or ``None``.
        num_topics: Number of LDA components.
        num_top_words: How many highest-weighted terms to report per topic.

    Returns:
        A DataFrame with columns ``Topic_ID`` ("Topic #1", ...), ``Word`` and
        ``Weight``, listing the top terms of each topic heaviest first. Returns
        ``None`` when there are fewer than two distinct entity strings or the
        model cannot be fitted (for example, when every term is a stop word).
    """
    if df_entities is None or df_entities.empty or 'text' not in df_entities:
        return None
    documents = df_entities['text'].dropna().unique().tolist()
    if len(documents) < 2:
        return None

    try:
        tfidf_vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), min_df=1)
        tfidf = tfidf_vectorizer.fit_transform(documents)
        lda = LatentDirichletAllocation(n_components=num_topics, random_state=42)
        lda.fit(tfidf)
    except Exception:
        # Best effort: topic modelling is optional. Catch Exception rather than
        # using a bare except, so KeyboardInterrupt, SystemExit and other
        # BaseException-derived control-flow signals still propagate.
        return None

    feature_names = tfidf_vectorizer.get_feature_names_out()
    topic_rows = [
        {'Topic_ID': f'Topic #{idx + 1}', 'Word': feature_names[i], 'Weight': topic[i]}
        for idx, topic in enumerate(lda.components_)
        # Indices of the num_top_words largest weights, largest first.
        for i in topic.argsort()[::-1][:num_top_words]
    ]
    return pd.DataFrame(topic_rows)
|
|
|
|
|
|
|
|
|
|
|
def create_topic_word_bubbles(df_topic_data):
    """Plot topic keywords as a bubble chart.

    Expects the frame produced by ``perform_topic_modeling`` (columns
    ``Topic_ID``, ``Word`` and ``Weight``). Each row becomes one labelled bubble.
    Bubbles are laid out left to right in row order. Both height and size
    follow the weight, and the colour identifies the topic.
    """
    plot_df = df_topic_data.rename(
        columns={'Topic_ID': 'topic', 'Word': 'word', 'Weight': 'weight'}
    )
    # Synthetic x coordinate that just spreads the bubbles out in row order.
    plot_df['x_pos'] = range(len(plot_df))

    chart = px.scatter(
        plot_df,
        x='x_pos',
        y='weight',
        size='weight',
        color='topic',
        text='word',
        title='Topic Word Weights',
    )
    chart.update_layout(
        margin=dict(t=80, b=50),
        xaxis_showticklabels=False,
        plot_bgcolor='#f9f9f9',
    )
    chart.update_traces(
        textposition='middle center',
        textfont=dict(color='white', size=10),
    )
    return chart
|
|
|
|
|
def generate_network_graph(df, raw_text):
    """Plot each distinct entity string as a node arranged on a circle.

    Node size grows with the number of times the string occurs in *df*. Node
    colour comes from ``entity_color_map`` using the label of the string's
    first occurrence; unknown labels are grey. No edges are drawn, so despite
    its title the figure shows frequency rather than relationships.

    Args:
        df: Entity DataFrame with at least ``text`` and ``label`` columns.
        raw_text: Accepted for interface compatibility; currently unused.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    frequency = df['text'].value_counts()
    # Map the counts onto the first row of each string. The previous merge on
    # value_counts().reset_index() depended on the pandas version: the key
    # column is named 'text' in pandas >= 2.0 but 'index' before, where the
    # merge on 'text' raised KeyError.
    nodes = df.drop_duplicates(subset=['text']).reset_index(drop=True)
    nodes['frequency'] = nodes['text'].map(frequency)

    # Spread the nodes evenly around a circle of radius 10.
    thetas = np.linspace(0, 2 * np.pi, len(nodes), endpoint=False)
    nodes['x'] = 10 * np.cos(thetas)
    nodes['y'] = 10 * np.sin(thetas)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=nodes['x'], y=nodes['y'], mode='markers+text', text=nodes['text'],
        marker=dict(
            size=nodes['frequency'] * 5 + 15,
            color=[entity_color_map.get(lbl, '#ccc') for lbl in nodes['label']],
        ),
    ))
    fig.update_layout(title="Entity Relationship Map", margin=dict(t=80), showlegend=False,
                      xaxis_visible=False, yaxis_visible=False)
    return fig
|
|
|
|
|
|
|
|
|
|
|
def generate_html_report(df, text_input, elapsed_time, df_topic_data):
    """Build a standalone HTML report of the analysis.

    Args:
        df: Entity DataFrame. Must provide the ``category``, ``label``,
            ``text`` and ``score`` columns used by the treemap, plus the
            ``start``/``end`` offsets used for highlighting.
        text_input: The text the entities were extracted from.
        elapsed_time: Processing time in seconds, shown in the header.
        df_topic_data: Output of ``perform_topic_modeling``, or ``None``. When
            it has rows, a topic bubble chart is added as section 3.

    Returns:
        The complete report as an HTML string.
    """
    fig_tree = px.treemap(df, path=[px.Constant("All"), 'category', 'label', 'text'],
                          values='score', title="Entity Hierarchy")
    fig_tree.update_layout(margin=dict(t=60, b=20, l=20, r=20))

    # Load plotly.js exactly once: the first chart pulls it from the CDN and the
    # later charts reuse it. The previously hard-coded plotly-latest.min.js is
    # frozen at plotly.js v1.58.5 and was loaded in addition to the per-chart
    # CDN copies.
    tree_html = fig_tree.to_html(full_html=False, include_plotlyjs='cdn')
    net_html = generate_network_graph(df, text_input).to_html(full_html=False, include_plotlyjs=False)

    # The report title promises topics, so include them when available.
    topic_section = ""
    if df_topic_data is not None and not df_topic_data.empty:
        topic_html = create_topic_word_bubbles(df_topic_data).to_html(full_html=False, include_plotlyjs=False)
        topic_section = f"""
<h2>3. Topics</h2>
<div class="chart-box">{topic_html}</div>
"""

    # The explicit charset stops browsers from mis-decoding non-ASCII text
    # when the downloaded file is opened locally.
    html_template = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>NER &amp; Topic Analysis Report</title>
<style>
body {{ font-family: sans-serif; background: #f4f7f6; padding: 30px; }}
.card {{ background: white; padding: 25px; border-radius: 12px; margin-bottom: 25px; box-shadow: 0 2px 10px rgba(0,0,0,0.05); }}
/* FIX: Critical for title visibility */
.chart-box {{ min-height: 500px; overflow: visible !important; border: 1px solid #eee; }}
h1, h2 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
</style>
</head>
<body>
<div class="card">
<h1>NER & Topic Analysis Report</h1>
<p>Processing Time: {elapsed_time:.2f}s</p>
<h2>1. Highlighted Entities</h2>
{highlight_entities(text_input, df)}
<h2>2. Visual Analytics</h2>
<div class="chart-box">{tree_html}</div>
<div class="chart-box">{net_html}</div>
{topic_section}
</div>
</body>
</html>
"""
    return html_template
|
|
|
|
|
def generate_pptx_report(df, max_rows=15):
    """Build a two-slide PowerPoint summary of the extracted entities.

    Slide 1 is a title slide. Slide 2 lists the first *max_rows* entities as
    "text (label)" bullets.

    Args:
        df: Entity DataFrame with ``text`` and ``label`` columns.
        max_rows: Maximum number of entities to list (default 15).

    Returns:
        A ``BytesIO`` holding the .pptx file, rewound to position 0.
    """
    prs = Presentation()

    title_slide = prs.slides.add_slide(prs.slide_layouts[0])
    title_slide.shapes.title.text = "Entity Analysis"

    list_slide = prs.slides.add_slide(prs.slide_layouts[1])
    list_slide.shapes.title.text = "Entity List"
    body = list_slide.placeholders[1].text_frame

    lines = [f"{row['text']} ({row['label']})" for row in df.head(max_rows).to_dict('records')]
    for i, line in enumerate(lines):
        # The body placeholder already contains one empty paragraph. Fill it
        # first instead of leaving a blank bullet at the top of the list.
        paragraph = body.paragraphs[0] if i == 0 else body.add_paragraph()
        paragraph.text = line
        # Shrink from the layout's default size so ~15 bullets fit the slide.
        paragraph.font.size = Pt(14)

    buffer = BytesIO()
    prs.save(buffer)
    buffer.seek(0)
    return buffer
|
|
|
|
|
|
|
|
|
|
|
# Configure the browser tab title and use the full-width page layout.
st.set_page_config(page_title="DataHarvest NER", layout="wide")
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the GLiNER multitask checkpoint.

    Wrapped in ``st.cache_resource`` so the model is downloaded and
    initialised once, then reused on every subsequent script rerun.
    """
    checkpoint = "knowledgator/gliner-multitask-large-v0.5"
    gliner_model = GLiNER.from_pretrained(checkpoint)
    return gliner_model
|
|
|
|
|
model = load_model()

# Initialise every session-state key the results view reads, so a rerun never
# hits a missing attribute.
for _key, _default in (
    ("results_df", pd.DataFrame()),
    ("show", False),
    ("analyzed_text", ""),
    ("elapsed", 0.0),
    ("topics", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

MAX_WORDS = 1000

st.subheader("Entity & Topic Analysis Report Generator", divider="blue")

text = st.text_area("Paste text here (max 1000 words):", height=250)

if st.button("Run Analysis"):
    word_count = len(text.split())
    if not text.strip():
        st.warning("Please paste some text to analyse.")
    elif word_count > MAX_WORDS:
        st.warning(f"The text has {word_count} words; please limit it to {MAX_WORDS}.")
    else:
        with st.spinner("Processing..."):
            start = time.time()
            entities = model.predict_entities(text, labels)
            df = pd.DataFrame(entities)
            if not df.empty:
                df['text'] = df['text'].apply(remove_trailing_punctuation)
                df['category'] = df['label'].map(reverse_category_mapping)
                st.session_state.results_df = df
                # Keep the exact text that was analysed: entity offsets index
                # into it, and the text area may be edited before the next rerun.
                st.session_state.analyzed_text = text
                st.session_state.elapsed = time.time() - start
                st.session_state.topics = perform_topic_modeling(df)
                st.session_state.show = True
            else:
                # Drop any previous run's results so they are not shown as if
                # they belonged to the new text.
                st.session_state.results_df = pd.DataFrame()
                st.session_state.show = False
                st.warning("No entities found.")

if st.session_state.show:
    df = st.session_state.results_df
    analyzed_text = st.session_state.analyzed_text

    st.markdown("### 1. Extracted Entities")
    st.markdown(highlight_entities(analyzed_text, df), unsafe_allow_html=True)

    t1, t2, t3 = st.tabs(["Charts", "Network Map", "Topics"])

    with t1:
        fig_tree = px.treemap(df, path=['category', 'label', 'text'], values='score', title="Entity Treemap")
        fig_tree.update_layout(margin=dict(t=50))
        st.plotly_chart(fig_tree, use_container_width=True)

    with t2:
        st.plotly_chart(generate_network_graph(df, analyzed_text), use_container_width=True)

    with t3:
        if st.session_state.topics is not None:
            st.plotly_chart(create_topic_word_bubbles(st.session_state.topics), use_container_width=True)
        else:
            st.info("Not enough data for topic modeling.")

    st.divider()
    st.markdown("### Download Artifacts")
    c1, c2, c3 = st.columns(3)

    with c1:
        st.download_button("Download HTML Report",
                           generate_html_report(df, analyzed_text, st.session_state.elapsed, st.session_state.topics),
                           "report.html", "text/html", type="primary")
    with c2:
        csv = df.to_csv(index=False).encode('utf-8')
        st.download_button("Download CSV Data", csv, "entities.csv", "text/csv")
    with c3:
        st.download_button("Download PPTX Summary", generate_pptx_report(df), "summary.pptx",
                           "application/vnd.openxmlformats-officedocument.presentationml.presentation")