# causev / plot.py
# Page-header residue from the hosting site, kept for provenance:
# "norygano's picture" - commit "Kolloquium" (adb4a34)
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
import umap
import streamlit as st
@st.cache_data
def load_data(file_path):
    """Read a tab-separated file into a pandas DataFrame.

    The function is decorated with Streamlit's ``st.cache_data``.

    Args:
        file_path: Location of the TSV file, passed straight to
            ``pd.read_csv``.

    Returns:
        The parsed table as a ``pandas.DataFrame``.
    """
    table = pd.read_csv(file_path, sep='\t')
    return table
class Plot:
    """Build the Plotly figures for the indicator/cause analysis.

    Two tab-separated tables are read once through ``load_data``: a feature
    matrix (one row per sentence, with ``indicator_*`` and ``cause_*``
    columns plus ``subfolder`` and ``text_date``) and a sentence metadata
    table (with ``subfolder`` and ``cause`` columns). A few derived columns
    and totals are computed up front so the chart methods only have to
    filter and aggregate.
    """

    # Chart variants understood by get_indicator_chart.
    _INDICATOR_CHART_TYPES = ('total', 'individual', 'year')

    def __init__(self, data_file='data/feature_matrix.tsv', metadata_file='data/indicator_cause_sentence_metadata.tsv'):
        """Load both tables and precompute shared columns and totals.

        Args:
            data_file: Path to the tab-separated feature matrix.
            metadata_file: Path to the tab-separated sentence metadata.
        """
        self.data_file = data_file
        self.metadata_file = metadata_file
        self.df = load_data(self.data_file)
        self.metadata_df = load_data(self.metadata_file)

        # Feature columns are recognised purely by their name prefix.
        self.indicator_columns = [col for col in self.df.columns if col.startswith('indicator_')]
        self.cause_columns = [col for col in self.df.columns if col.startswith('cause_')]

        # 'Year' is the first four characters of the stringified text_date.
        self.df['Year'] = self.df['text_date'].astype(str).str[:4]
        self.df['Has_Indicator'] = self.df[self.indicator_columns].sum(axis=1) > 0

        # Denominators reused by the share computations in the charts.
        self.total_sentences_per_year = self.df.groupby(['Year', 'subfolder']).size().reset_index(name='Total Sentences')
        self.total_sentences_per_subfolder = self.df.groupby('subfolder').size().reset_index(name='Total Sentences')

    @staticmethod
    def _active_labels(frame, columns, prefix):
        """Join, per row, the names of ``columns`` whose value is > 0.

        Args:
            frame: DataFrame holding the numeric ``columns``.
            columns: Column names to inspect, in output order.
            prefix: Text removed from each column name in the output.

        Returns:
            A Series indexed like ``frame`` with comma-separated names
            (an empty string when no column is active in that row).
        """
        names = [column.replace(prefix, '') for column in columns]
        active = (frame[columns] > 0).to_numpy()
        joined = [
            ', '.join(name for name, is_on in zip(names, row) if is_on)
            for row in active
        ]
        return pd.Series(joined, index=frame.index, dtype=object)

    @staticmethod
    def _hex_to_rgba(hex_color, opacity):
        """Convert a ``#RRGGBB`` colour into a CSS ``rgba(...)`` string."""
        red, green, blue = (int(hex_color[start:start + 2], 16) for start in (1, 3, 5))
        return f'rgba({red}, {green}, {blue}, {opacity})'

    def get_indicator_chart(self, chart_type='total', individual_threshold=5):
        """Return a bar chart describing indicator usage.

        Args:
            chart_type: ``'total'`` plots the share of sentences with at
                least one indicator per subfolder; ``'individual'`` plots
                per-indicator occurrence counts per subfolder; ``'year'``
                plots total sentences per year, labelled with the
                indicator share.
            individual_threshold: For ``'individual'`` only, the minimum
                number of occurrences across all subfolders an indicator
                needs in order to be shown.

        Returns:
            A Plotly figure.

        Raises:
            ValueError: If ``chart_type`` is not one of the supported
                variants.
        """
        # Fail early with a clear message instead of an UnboundLocalError
        # on ``fig`` further down.
        if chart_type not in self._INDICATOR_CHART_TYPES:
            raise ValueError(
                f"Unknown chart_type {chart_type!r}; expected one of {self._INDICATOR_CHART_TYPES}"
            )

        if chart_type == 'total':
            # Share of sentences carrying at least one indicator, per subfolder.
            indicator_counts = self.df[self.df['Has_Indicator']].groupby('subfolder').size().reset_index(name='Indicator Count')
            total_counts = indicator_counts.merge(self.total_sentences_per_subfolder, on='subfolder')
            total_counts['Indicator_Share'] = total_counts['Indicator Count'] / total_counts['Total Sentences']
            total_counts['Indicator_Share_Text'] = (total_counts['Indicator_Share'] * 100).round(2).astype(str) + '%'
            fig = px.bar(
                total_counts,
                x='subfolder',
                y='Indicator_Share',
                labels={'Indicator_Share': 'Share of Sentences with Indicators', 'subfolder': ''},
                color='subfolder',
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                textposition='inside',
                insidetextanchor='middle',
                texttemplate='%{text}',
                textfont=dict(color='rgb(255, 255, 255)')
            )
        elif chart_type == 'individual':
            # Long format: one row per (sentence, indicator) with a positive value.
            df_melted = self.df.melt(id_vars=['subfolder'], value_vars=self.indicator_columns, var_name='Indicator', value_name='Count')
            df_melted = df_melted[df_melted['Count'] > 0]

            # The threshold applies to the total across all subfolders.
            total_indicator_counts = df_melted.groupby('Indicator').size().reset_index(name='Total Count')
            indicators_meeting_threshold = total_indicator_counts[total_indicator_counts['Total Count'] >= individual_threshold]['Indicator'].unique()

            # Copy so the column assignment below does not write into a slice.
            df_melted = df_melted[df_melted['Indicator'].isin(indicators_meeting_threshold)].copy()
            df_melted['Indicator'] = df_melted['Indicator'].str.replace('indicator_', '', regex=False).str.capitalize()

            # Re-aggregate per subfolder for the indicators that survived.
            df_melted = df_melted.groupby(['subfolder', 'Indicator']).size().reset_index(name='Count')
            fig = px.bar(
                df_melted,
                x='subfolder',
                y='Count',
                color='Indicator',
                barmode='group',
                labels={'Count': 'Occurrences', 'subfolder': '', 'Indicator': 'Indicator'},
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                texttemplate='%{y}',
                textposition='inside',
                insidetextanchor='middle',
                textfont=dict(color='rgb(255, 255, 255)')
            )
        else:  # chart_type == 'year'
            indicator_counts_per_year = self.df[self.df['Has_Indicator']].groupby(['Year', 'subfolder']).size().reset_index(name='Indicator Count')
            df_summary = pd.merge(self.total_sentences_per_year, indicator_counts_per_year, on=['Year', 'subfolder'], how='left')
            # The left merge leaves NaN where a year/subfolder has no
            # indicator sentences; treat that as zero so the label reads
            # '0.0%' rather than 'nan%'.
            df_summary['Indicator Count'] = df_summary['Indicator Count'].fillna(0)
            df_summary['Indicator_Share_Text'] = (df_summary['Indicator Count'] / df_summary['Total Sentences'] * 100).round(2).astype(str) + '%'
            fig = px.bar(
                df_summary,
                x='Year',
                y='Total Sentences',
                color='subfolder',
                labels={'Total Sentences': 'Total Number of Sentences', 'Year': 'Year'},
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3
            )
            fig.update_traces(
                textposition='inside',
                texttemplate='%{text}',
                insidetextanchor='middle',
                textfont=dict(color='rgb(255, 255, 255)')
            )

        fig.update_layout(
            xaxis=dict(showline=True),
            yaxis=dict(title='Indicator Sentences' if chart_type != 'year' else 'Total Sentences'),
            bargap=0.05,
            # With chart_type 'total' the bars are already labelled by subfolder.
            showlegend=(chart_type != 'total')
        )
        return fig

    def get_causes_chart(self, min_value=30):
        """Return a grouped bar chart of cause occurrences per subfolder.

        Args:
            min_value: Minimum number of metadata rows a cause needs
                (across all subfolders) in order to be shown.

        Returns:
            A Plotly figure.
        """
        # NOTE: with pandas' default NA handling, a literal "N/A" in the TSV
        # is read as NaN, not as the string 'N/A'. Such rows still drop out
        # below because groupby ignores NaN keys and isin() does not match.
        df_filtered = self.metadata_df[self.metadata_df['cause'] != 'N/A']
        causes_meeting_threshold = df_filtered.groupby('cause')['cause'].count()[lambda x: x >= min_value].index

        # Copy so the column assignment below does not write into a slice.
        df_filtered = df_filtered[df_filtered['cause'].isin(causes_meeting_threshold)].copy()
        df_filtered['cause'] = df_filtered['cause'].str.capitalize()

        fig = px.bar(
            df_filtered.groupby(['subfolder', 'cause']).size().reset_index(name='Count'),
            x='subfolder',
            y='Count',
            color='cause',
            barmode='group',
            labels={'Count': 'Occurrences', 'subfolder': '', 'cause': 'Cause'},
            color_discrete_sequence=px.colors.qualitative.D3
        )
        fig.update_layout(xaxis=dict(showline=True), yaxis=dict(showticklabels=True, title=''))
        fig.update_traces(
            texttemplate='%{y}',
            textposition='inside',
            insidetextanchor='middle',
            textfont=dict(color='rgb(255, 255, 255)')
        )
        return fig

    def scatter(self, include_modality=False):
        """Return a 2-D UMAP scatter plot of the sentence feature vectors.

        Only sentences with at least one indicator or cause are considered,
        and of those only sentences using an indicator that occurs at least
        10 times (``indicator_!besprechen`` is left out of that selection).

        Args:
            include_modality: When True, the ``modality_*`` columns are kept
                as UMAP features and the active modalities are shown on
                hover.

        Returns:
            A Plotly figure coloured by subfolder, with one marker symbol
            per indicator combination.
        """
        has_indicator = self.df[self.indicator_columns].sum(axis=1) > 0
        has_cause = self.df[self.cause_columns].sum(axis=1) > 0
        df_filtered = self.df[has_indicator | has_cause]

        # Keep rows using at least one sufficiently frequent indicator.
        indicator_columns = [col for col in self.indicator_columns if 'indicator_!besprechen' not in col]
        indicator_counts = df_filtered[indicator_columns].sum()
        indicators_to_keep = indicator_counts[indicator_counts >= 10].index.tolist()
        df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]

        # Non-feature columns never enter the dimensionality reduction.
        modality_columns = [col for col in self.df.columns if col.startswith('modality_')]
        columns_to_drop = ['subfolder', 'text_id', 'sentence_id', 'text_date', 'text_source', 'text_text_type']
        if not include_modality:
            columns_to_drop += modality_columns
        features = df_filtered.drop(columns=columns_to_drop, errors='ignore').select_dtypes(include=[float, int])
        features_clean = features.fillna(0)

        # Per-point metadata used for colour, symbol and hover text.
        metadata = df_filtered[['subfolder']].copy()
        metadata['indicator'] = self._active_labels(df_filtered, indicators_to_keep, 'indicator_')
        metadata['cause'] = self._active_labels(df_filtered, self.cause_columns, 'cause_')
        if include_modality:
            # The hover entry below needs a real 'Modality' column; without
            # it px.scatter rejects the hover_data key. Assumes modality
            # columns are numeric flags/counts where > 0 means "active" -
            # TODO confirm against the feature matrix.
            numeric_modality = df_filtered[modality_columns].select_dtypes(include='number').columns.tolist()
            if numeric_modality:
                metadata['Modality'] = self._active_labels(df_filtered, numeric_modality, 'modality_')

        reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
        reduced_features = reducer.fit_transform(features_clean)
        df_reduced = pd.DataFrame(reduced_features, columns=['UMAP x', 'UMAP y'])
        # The UMAP output is positionally indexed, so drop the original index.
        df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)

        hover_data = {'cause': True, 'UMAP x': False, 'UMAP y': False}
        if 'Modality' in df_reduced.columns:
            hover_data['Modality'] = True

        fig = px.scatter(
            df_reduced,
            x='UMAP x',
            y='UMAP y',
            color='subfolder',
            symbol='indicator',
            labels={'subfolder': 'Effect'},
            hover_data=hover_data,
            color_discrete_sequence=px.colors.qualitative.D3
        )
        fig.update_layout(
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            showlegend=True,
            legend=dict(title="Effect, Indicator", yanchor="top", xanchor="left", borderwidth=1),
        )
        return fig

    def sankey(self, cause_threshold=10, indicator_threshold=5, link_opacity=0.4):
        """Return a Sankey diagram flowing cause -> indicator -> subfolder.

        Only sentences that have both a cause and an indicator are used.
        Occurrences are deduplicated per ``text_id``, and a cause is linked
        to every indicator that shares its ``text_id`` and ``subfolder``.

        Args:
            cause_threshold: Minimum number of deduplicated occurrences a
                cause needs to appear as a node.
            indicator_threshold: Same minimum for indicators.
            link_opacity: Alpha value applied to the link colours.

        Returns:
            A Plotly figure (800x600).
        """
        has_cause = self.df[self.cause_columns].sum(axis=1) > 0
        has_indicator = self.df[self.indicator_columns].sum(axis=1) > 0
        df_filtered = self.df[has_cause & has_indicator]

        # Long format, one row per (text_id, cause) / (text_id, indicator).
        cause_data = df_filtered[['text_id', 'subfolder'] + self.cause_columns].melt(
            id_vars=['text_id', 'subfolder'], var_name='cause', value_name='count'
        ).query("count > 0").drop_duplicates(['text_id', 'cause'])
        indicator_data = df_filtered[['text_id', 'subfolder'] + self.indicator_columns].melt(
            id_vars=['text_id', 'subfolder'], var_name='indicator', value_name='count'
        ).query("count > 0").drop_duplicates(['text_id', 'indicator'])

        # Drop rare causes and indicators.
        valid_causes = cause_data['cause'].value_counts()[lambda x: x >= cause_threshold].index
        valid_indicators = indicator_data['indicator'].value_counts()[lambda x: x >= indicator_threshold].index
        cause_data = cause_data[cause_data['cause'].isin(valid_causes)]
        indicator_data = indicator_data[indicator_data['indicator'].isin(valid_indicators)]

        cause_indicator_links = (
            cause_data.merge(indicator_data, on=['text_id', 'subfolder'])
            .groupby(['cause', 'indicator']).size().reset_index(name='count')
        )
        indicator_subfolder_links = (
            indicator_data.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
        )

        # Node order: causes, then indicators, then subfolders.
        all_labels = list(valid_causes) + list(valid_indicators) + self.df['subfolder'].unique().tolist()
        all_labels_cleaned = [label.replace("cause_", "").replace("indicator_", "") for label in all_labels]
        label_to_index = {label: idx for idx, label in enumerate(all_labels)}

        # Node colours cycle through the D3 palette.
        color_palette = px.colors.qualitative.D3
        node_colors = [color_palette[i % len(color_palette)] for i in range(len(all_labels))]

        sources, targets, values, link_colors = [], [], [], []

        def add_links(links, source_column, target_column):
            """Append every row of ``links`` whose endpoints are known nodes."""
            rows = zip(links[source_column], links[target_column], links['count'])
            for source_label, target_label, count in rows:
                if source_label not in label_to_index or target_label not in label_to_index:
                    continue
                source_idx = label_to_index[source_label]
                sources.append(source_idx)
                targets.append(label_to_index[target_label])
                values.append(int(count))
                # A link inherits the colour of the node it starts from.
                link_colors.append(self._hex_to_rgba(node_colors[source_idx], link_opacity))

        add_links(cause_indicator_links, 'cause', 'indicator')
        add_links(indicator_subfolder_links, 'indicator', 'subfolder')

        fig = go.Figure(data=[go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=all_labels_cleaned,
                color=node_colors
            ),
            link=dict(
                source=sources,
                target=targets,
                value=values,
                color=link_colors
            )
        )])
        fig.update_layout(
            autosize=False,
            width=800,
            height=600,
            font=dict(size=10)
        )
        return fig