|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import os |
|
import umap |
|
import streamlit as st |
|
|
|
@st.cache_data
def load_data(file_path):
    """Read the tab-separated file at *file_path* into a DataFrame.

    The function is wrapped in ``st.cache_data``, so Streamlit caches the
    result per ``file_path``.
    """
    table = pd.read_csv(file_path, sep='\t')
    return table
|
|
|
class Plot:
    """Build Plotly figures from a sentence-level feature matrix and its metadata.

    The feature matrix is read from ``data_file`` and is expected to provide
    the ``subfolder``, ``text_id`` and ``text_date`` columns plus numeric
    ``indicator_*`` and ``cause_*`` columns. The metadata table is read from
    ``metadata_file`` and is expected to provide ``subfolder`` and ``cause``.
    """

    # Chart variants understood by get_indicator_chart.
    _CHART_TYPES = ('total', 'individual', 'year')

    # Shared trace styling: white value labels centred inside each bar.
    _BAR_TEXT_STYLE = dict(
        textposition='inside',
        insidetextanchor='middle',
        textfont=dict(color='rgb(255, 255, 255)'),
    )

    # Identifier columns that must never be fed to UMAP as features.
    _NON_FEATURE_COLUMNS = (
        'subfolder', 'text_id', 'sentence_id', 'text_date', 'text_source', 'text_text_type',
    )

    def __init__(self, data_file='data/feature_matrix.tsv', metadata_file='data/indicator_cause_sentence_metadata.tsv'):
        """Load both tables and precompute the per-group sentence totals."""
        self.data_file = data_file
        self.metadata_file = metadata_file
        self.df = load_data(self.data_file)
        self.metadata_df = load_data(self.metadata_file)

        self.indicator_columns = [col for col in self.df.columns if col.startswith('indicator_')]
        self.cause_columns = [col for col in self.df.columns if col.startswith('cause_')]

        # The year is taken as the first four characters of the date string.
        self.df['Year'] = self.df['text_date'].astype(str).str[:4]
        self.df['Has_Indicator'] = self.df[self.indicator_columns].sum(axis=1) > 0

        self.total_sentences_per_year = (
            self.df.groupby(['Year', 'subfolder']).size().reset_index(name='Total Sentences')
        )
        self.total_sentences_per_subfolder = (
            self.df.groupby('subfolder').size().reset_index(name='Total Sentences')
        )

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _strip_prefix(label, prefix):
        """Return *label* without a leading *prefix*; other labels pass through unchanged."""
        if isinstance(label, str) and label.startswith(prefix):
            return label[len(prefix):]
        return label

    @staticmethod
    def _hex_to_rgba(hex_color, opacity):
        """Convert a ``#RRGGBB`` colour into a CSS ``rgba(...)`` string."""
        red, green, blue = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
        return f'rgba({red}, {green}, {blue}, {opacity})'

    @classmethod
    def _active_labels(cls, frame, columns, prefix):
        """Return, per row of *frame*, the comma-joined names of *columns* whose value is > 0.

        Built from the raw values instead of ``DataFrame.apply(axis=1)`` so the
        result is always a Series, even when *columns* or *frame* is empty.
        """
        columns = list(columns)
        names = [cls._strip_prefix(col, prefix) for col in columns]
        labels = [
            ', '.join(name for name, value in zip(names, row) if value > 0)
            for row in frame[columns].to_numpy()
        ]
        return pd.Series(labels, index=frame.index, dtype=object)

    @staticmethod
    def _active_per_text(frame, columns, var_name):
        """Melt *columns* to long form and keep one row per (text_id, column) with a value > 0."""
        id_vars = ['text_id', 'subfolder']
        melted = frame[id_vars + list(columns)].melt(
            id_vars=id_vars, var_name=var_name, value_name='count'
        )
        return melted[melted['count'] > 0].drop_duplicates(['text_id', var_name])

    # ------------------------------------------------------------------
    # Data preparation for the bar charts
    # ------------------------------------------------------------------
    def _total_indicator_shares(self):
        """Return, per subfolder, the share of sentences containing at least one indicator."""
        indicator_counts = (
            self.df[self.df['Has_Indicator']]
            .groupby('subfolder').size().reset_index(name='Indicator Count')
        )
        shares = indicator_counts.merge(self.total_sentences_per_subfolder, on='subfolder')
        shares['Indicator_Share'] = shares['Indicator Count'] / shares['Total Sentences']
        shares['Indicator_Share_Text'] = (shares['Indicator_Share'] * 100).round(2).astype(str) + '%'
        return shares

    def _individual_indicator_counts(self, individual_threshold):
        """Return occurrence counts per (subfolder, indicator) for sufficiently frequent indicators."""
        melted = self.df.melt(
            id_vars=['subfolder'], value_vars=self.indicator_columns,
            var_name='Indicator', value_name='Count',
        )
        melted = melted[melted['Count'] > 0]

        # Keep only indicators that occur at least `individual_threshold` times overall.
        occurrences = melted['Indicator'].value_counts()
        frequent = occurrences[occurrences >= individual_threshold].index

        # Copy before assigning so the write does not target a view of `melted`.
        melted = melted[melted['Indicator'].isin(frequent)].copy()
        melted['Indicator'] = (
            melted['Indicator']
            .map(lambda name: self._strip_prefix(name, 'indicator_'))
            .str.capitalize()
        )
        return melted.groupby(['subfolder', 'Indicator']).size().reset_index(name='Count')

    def _yearly_indicator_summary(self):
        """Return sentence totals per (Year, subfolder) with the indicator share as label text."""
        indicator_counts = (
            self.df[self.df['Has_Indicator']]
            .groupby(['Year', 'subfolder']).size().reset_index(name='Indicator Count')
        )
        summary = pd.merge(
            self.total_sentences_per_year, indicator_counts,
            on=['Year', 'subfolder'], how='left',
        )
        # The left merge leaves NaN where a group has no indicator sentence;
        # treat that as zero so the label reads "0.0%" rather than "nan%".
        summary['Indicator Count'] = summary['Indicator Count'].fillna(0).astype(int)
        summary['Indicator_Share_Text'] = (
            summary['Indicator Count'] / summary['Total Sentences'] * 100
        ).round(2).astype(str) + '%'
        return summary

    def _cause_counts(self, min_value):
        """Return occurrence counts per (subfolder, cause) for causes seen at least *min_value* times."""
        causes = self.metadata_df['cause']
        # pandas.read_csv parses the literal "N/A" as NaN by default, so missing
        # causes can show up either way; exclude both explicitly.
        known = self.metadata_df[causes.notna() & (causes != 'N/A')]

        occurrences = known['cause'].value_counts()
        frequent = occurrences[occurrences >= min_value].index

        kept = known[known['cause'].isin(frequent)].copy()
        kept['cause'] = kept['cause'].astype(str).str.capitalize()
        return kept.groupby(['subfolder', 'cause']).size().reset_index(name='Count')

    # ------------------------------------------------------------------
    # Public figures
    # ------------------------------------------------------------------
    def get_indicator_chart(self, chart_type='total', individual_threshold=5):
        """Return a bar chart about indicator usage.

        Args:
            chart_type: ``'total'`` for the share of indicator sentences per
                subfolder, ``'individual'`` for grouped counts per indicator,
                or ``'year'`` for sentence totals per year labelled with the
                indicator share.
            individual_threshold: Minimum overall occurrences an indicator
                needs to appear in the ``'individual'`` chart.

        Returns:
            The Plotly figure.

        Raises:
            ValueError: If ``chart_type`` is not one of the supported values.
        """
        if chart_type not in self._CHART_TYPES:
            raise ValueError(
                f'Unknown chart_type {chart_type!r}; expected one of {self._CHART_TYPES}'
            )

        if chart_type == 'total':
            fig = px.bar(
                self._total_indicator_shares(),
                x='subfolder',
                y='Indicator_Share',
                labels={'Indicator_Share': 'Share of Sentences with Indicators', 'subfolder': ''},
                color='subfolder',
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3,
            )
            fig.update_traces(texttemplate='%{text}', **self._BAR_TEXT_STYLE)
        elif chart_type == 'individual':
            fig = px.bar(
                self._individual_indicator_counts(individual_threshold),
                x='subfolder',
                y='Count',
                color='Indicator',
                barmode='group',
                labels={'Count': 'Occurrences', 'subfolder': '', 'Indicator': 'Indicator'},
                color_discrete_sequence=px.colors.qualitative.D3,
            )
            fig.update_traces(texttemplate='%{y}', **self._BAR_TEXT_STYLE)
        else:  # 'year'
            fig = px.bar(
                self._yearly_indicator_summary(),
                x='Year',
                y='Total Sentences',
                color='subfolder',
                labels={'Total Sentences': 'Total Number of Sentences', 'Year': 'Year'},
                text='Indicator_Share_Text',
                color_discrete_sequence=px.colors.qualitative.D3,
            )
            fig.update_traces(texttemplate='%{text}', **self._BAR_TEXT_STYLE)

        fig.update_layout(
            xaxis=dict(showline=True),
            yaxis=dict(title='Indicator Sentences' if chart_type != 'year' else 'Total Sentences'),
            bargap=0.05,
            # With one bar per subfolder the legend would only repeat the x axis.
            showlegend=(chart_type != 'total'),
        )
        return fig

    def get_causes_chart(self, min_value=30):
        """Return a grouped bar chart of cause occurrences per subfolder.

        Args:
            min_value: Minimum overall occurrences a cause needs to be shown.
        """
        fig = px.bar(
            self._cause_counts(min_value),
            x='subfolder',
            y='Count',
            color='cause',
            barmode='group',
            labels={'Count': 'Occurrences', 'subfolder': '', 'cause': 'Cause'},
            color_discrete_sequence=px.colors.qualitative.D3,
        )
        fig.update_layout(xaxis=dict(showline=True), yaxis=dict(showticklabels=True, title=''))
        fig.update_traces(texttemplate='%{y}', **self._BAR_TEXT_STYLE)
        return fig

    def scatter(self, include_modality=False, min_indicator_count=10):
        """Return a 2-D UMAP scatter plot of sentences that carry an indicator.

        Args:
            include_modality: When true, the ``modality_*`` columns are kept as
                UMAP features and the active modalities are shown on hover.
            min_indicator_count: Minimum column sum an indicator needs (over
                the sentences with any indicator or cause) to be kept.

        Raises:
            ValueError: If no sentence or no numeric feature is left to embed.
        """
        has_indicator = self.df[self.indicator_columns].sum(axis=1) > 0
        has_cause = self.df[self.cause_columns].sum(axis=1) > 0
        df_filtered = self.df[has_indicator | has_cause]

        # 'indicator_!besprechen' is deliberately left out of the selection.
        candidates = [col for col in self.indicator_columns if 'indicator_!besprechen' not in col]
        indicator_counts = df_filtered[candidates].sum()
        indicators_to_keep = indicator_counts[indicator_counts >= min_indicator_count].index.tolist()
        df_filtered = df_filtered[df_filtered[indicators_to_keep].sum(axis=1) > 0]

        modality_columns = [col for col in self.df.columns if col.startswith('modality_')]
        columns_to_drop = list(self._NON_FEATURE_COLUMNS)
        if not include_modality:
            columns_to_drop += modality_columns

        features = (
            df_filtered.drop(columns=columns_to_drop, errors='ignore')
            .select_dtypes(include=[float, int])
            .fillna(0)
        )
        if features.empty:
            raise ValueError('No sentences or numeric features left to embed with UMAP.')

        metadata = df_filtered[['subfolder']].copy()
        metadata['indicator'] = self._active_labels(df_filtered, indicators_to_keep, 'indicator_')
        metadata['cause'] = self._active_labels(df_filtered, self.cause_columns, 'cause_')

        hover_data = {'cause': True, 'UMAP x': False, 'UMAP y': False}
        if include_modality:
            # px.scatter rejects hover_data keys that are not columns of the
            # plotted frame, so the column has to exist before it is requested.
            # NOTE(review): assumes modality_* columns are numeric flags like
            # the indicator_*/cause_* columns -- confirm against the data.
            numeric_modalities = (
                df_filtered[modality_columns].select_dtypes(include='number').columns.tolist()
            )
            metadata['Modality'] = self._active_labels(df_filtered, numeric_modalities, 'modality_')
            hover_data['Modality'] = True

        reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=50, n_jobs=1, metric='cosine')
        embedding = reducer.fit_transform(features)
        df_reduced = pd.DataFrame(embedding, columns=['UMAP x', 'UMAP y'])
        # The embedding is positionally indexed, so drop the original index to align rows.
        df_reduced = pd.concat([df_reduced, metadata.reset_index(drop=True)], axis=1)

        fig = px.scatter(
            df_reduced,
            x='UMAP x',
            y='UMAP y',
            color='subfolder',
            symbol='indicator',
            labels={'subfolder': 'Effect'},
            hover_data=hover_data,
            color_discrete_sequence=px.colors.qualitative.D3,
        )
        fig.update_layout(
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            showlegend=True,
            legend=dict(title="Effect, Indicator", yanchor="top", xanchor="left", borderwidth=1),
        )
        return fig

    def sankey(self, cause_threshold=10, indicator_threshold=5, link_opacity=0.4):
        """Return a Sankey diagram linking causes to indicators to subfolders.

        Only sentences with at least one cause and one indicator are used, and
        each cause/indicator is counted once per ``text_id``.

        Args:
            cause_threshold: Minimum number of texts a cause must occur in.
            indicator_threshold: Minimum number of texts an indicator must occur in.
            link_opacity: Alpha value applied to the link colours.
        """
        has_cause = self.df[self.cause_columns].sum(axis=1) > 0
        has_indicator = self.df[self.indicator_columns].sum(axis=1) > 0
        df_filtered = self.df[has_cause & has_indicator]

        cause_data = self._active_per_text(df_filtered, self.cause_columns, 'cause')
        indicator_data = self._active_per_text(df_filtered, self.indicator_columns, 'indicator')

        valid_causes = cause_data['cause'].value_counts()[lambda x: x >= cause_threshold].index
        valid_indicators = indicator_data['indicator'].value_counts()[lambda x: x >= indicator_threshold].index
        cause_data = cause_data[cause_data['cause'].isin(valid_causes)]
        indicator_data = indicator_data[indicator_data['indicator'].isin(valid_indicators)]

        cause_indicator_links = (
            cause_data.merge(indicator_data, on=['text_id', 'subfolder'])
            .groupby(['cause', 'indicator']).size().reset_index(name='count')
        )
        indicator_subfolder_links = (
            indicator_data.groupby(['indicator', 'subfolder']).size().reset_index(name='count')
        )

        # Node order: causes, then indicators, then every subfolder in the data.
        subfolders = self.df['subfolder'].unique().tolist()
        all_labels = list(valid_causes) + list(valid_indicators) + subfolders
        # Strip only the matching leading prefix; a blanket str.replace would
        # also eat "cause_" inside names such as "indicator_because_x".
        display_labels = (
            [self._strip_prefix(label, 'cause_') for label in valid_causes]
            + [self._strip_prefix(label, 'indicator_') for label in valid_indicators]
            + subfolders
        )
        label_to_index = {label: idx for idx, label in enumerate(all_labels)}

        color_palette = px.colors.qualitative.D3
        node_colors = [color_palette[i % len(color_palette)] for i in range(len(all_labels))]

        sources, targets, values, link_colors = [], [], [], []
        link_frames = (
            (cause_indicator_links, 'cause', 'indicator'),
            (indicator_subfolder_links, 'indicator', 'subfolder'),
        )
        # Every label in these frames comes from the filtered data above and is
        # therefore registered in label_to_index.
        for links, source_col, target_col in link_frames:
            for source, target, count in zip(links[source_col], links[target_col], links['count']):
                source_idx = label_to_index[source]
                sources.append(source_idx)
                targets.append(label_to_index[target])
                values.append(count)
                # Links inherit the colour of the node they start from.
                link_colors.append(self._hex_to_rgba(node_colors[source_idx], link_opacity))

        fig = go.Figure(data=[go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=display_labels,
                color=node_colors,
            ),
            link=dict(
                source=sources,
                target=targets,
                value=values,
                color=link_colors,
            ),
        )])
        fig.update_layout(
            autosize=False,
            width=800,
            height=600,
            font=dict(size=10),
        )
        return fig
|
|