|
import io
import os
import traceback

import gradio as gr
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
|
|
|
|
|
# Folder (relative to the working directory) that holds the sample workbooks.
EXAMPLE_DATA_DIR = "eg_data"

# Maps each logical input name to the path of its sample workbook inside
# EXAMPLE_DATA_DIR. Insertion order follows the order of the upload widgets.
EXAMPLE_FILES = {
    input_name: os.path.join(EXAMPLE_DATA_DIR, workbook)
    for input_name, workbook in (
        ("cashflow_base", "cashflows_seriatim_10K.xlsx"),
        ("cashflow_lapse", "cashflows_seriatim_10K_lapse50.xlsx"),
        ("cashflow_mort", "cashflows_seriatim_10K_mort15.xlsx"),
        ("policy_data", "model_point_table.xlsx"),
        ("pv_base", "pv_seriatim_10K.xlsx"),
        ("pv_lapse", "pv_seriatim_10K_lapse50.xlsx"),
        ("pv_mort", "pv_seriatim_10K_mort15.xlsx"),
    )
}
|
|
|
class Clusters:
    """K-means clustering of policies, with helpers that compare cluster proxies to the full data.

    After construction the instance exposes:

    * ``kmeans`` - the fitted ``KMeans`` estimator (``labels_`` holds one
      cluster id per input row).
    * ``rep_ids`` - Series named ``policy_id`` indexed by ``cluster_id``: the
      id of the input row closest to each cluster centre.
    * ``policy_count`` - Series indexed by ``cluster_id``: number of input
      rows assigned to each cluster.
    """

    def __init__(self, loc_vars, n_clusters=1000):
        """Fit the clustering on ``loc_vars`` and pick one representative row per cluster.

        Args:
            loc_vars: DataFrame of calibration variables, one row per policy.
            n_clusters: Upper bound on the number of clusters; the value
                actually used is ``min(n_clusters, len(loc_vars))``.

        Raises:
            ValueError: If ``loc_vars`` is empty or consists only of NaN.
        """
        if loc_vars.empty:
            raise ValueError("Input data for KMeans (loc_vars) is empty.")
        if loc_vars.isnull().all().all():
            raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")

        points = np.ascontiguousarray(loc_vars)
        self.kmeans = KMeans(
            n_clusters=min(n_clusters, len(loc_vars)), random_state=0, n_init=10
        ).fit(points)

        # For every cluster centre, the positional index of the nearest input row.
        closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, points)
        self.rep_ids = self._rep_policy_ids(loc_vars, closest)

        # Summing a column of ones per cluster gives the cluster sizes.
        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * len(loc_vars)})
        )['policy_count']

    @staticmethod
    def _rep_policy_ids(loc_vars, closest):
        """Translate row positions of the representatives into policy ids.

        When ``loc_vars`` is indexed by ``policy_id`` the real index labels are
        used, so ids need not be the consecutive integers 1..N. Otherwise the
        historical convention ``position + 1`` is kept.
        """
        positions = np.asarray(closest)
        if loc_vars.index.name == 'policy_id':
            ids = np.asarray(loc_vars.index)[positions]
        else:
            ids = positions + 1
        rep_ids = pd.Series(data=ids)
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        return rep_ids

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the rows of ``df`` per cluster.

        ``df`` must have one row per clustered policy, in the same order as the
        data passed to the constructor (labels are attached positionally).

        Args:
            df: DataFrame to aggregate.
            agg: ``None`` sums every column. A dict maps column name to an
                aggregation, with columns missing from the dict summed. Any
                other value also results in a plain sum of every column.

        Returns:
            DataFrame indexed by ``cluster_id``.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')

        if agg is None:
            agg_ops = "sum"
        elif isinstance(agg, dict):
            agg_ops = {c: agg.get(c, 'sum') for c in temp.columns}
        else:
            agg_ops = {col: "sum" for col in temp.columns}

        return temp.groupby(temp.index).agg(agg_ops)

    def extract_reps(self, df):
        """Return the rows of ``df`` that belong to the representative policies.

        ``df`` must expose a ``policy_id`` column after ``reset_index()`` (in
        practice: an index named ``policy_id``). The result has one row per
        cluster, indexed by ``cluster_id``; representatives missing from ``df``
        yield NaN rows because a left merge is used.
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Return representative rows scaled up by their cluster's policy count.

        With a non-empty dict ``agg`` only the columns whose aggregation is
        ``'sum'`` (the default for unlisted columns) are scaled; otherwise
        every column is multiplied by the policy count.
        """
        extracted_df = self.extract_reps(df)
        if extracted_df.empty:
            return extracted_df

        if agg and isinstance(agg, dict):
            scaled_df = extracted_df.copy()
            for c in extracted_df.columns:
                if agg.get(c, 'sum') == 'sum':
                    scaled_df[c] = extracted_df[c].mul(self.policy_count, axis=0)
            return scaled_df

        return extracted_df.mul(self.policy_count, axis=0)

    def compare(self, df, agg=None):
        """Compare per-cluster aggregates with the scaled representative values.

        Returns:
            DataFrame with columns ``actual`` (aggregate over all policies of
            the cluster) and ``estimate`` (scaled representative), indexed by
            the stacked (cluster_id, column) pairs.
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Aggregate df by columns and compare actual vs estimate totals.

        Args:
            df: DataFrame with one row per policy, indexed by ``policy_id``.
            agg: Optional dict mapping column name to ``'sum'`` or ``'mean'``
                (unlisted columns default to ``'sum'``). Without a dict, only
                numeric columns are considered and they are summed.

        Returns:
            DataFrame indexed by column name with columns ``actual``,
            ``estimate`` and ``error`` (``estimate / actual - 1``). ``error``
            is 0 when both values are 0 and NaN when it is undefined or
            infinite. An empty frame with those three columns is returned when
            ``df`` is empty or has nothing to aggregate.
        """
        empty_result = pd.DataFrame(columns=['actual', 'estimate', 'error'])
        if df.empty:
            return empty_result

        if isinstance(agg, dict):
            op_for_actual = {c: agg.get(c, 'sum') for c in df.columns}
        else:
            op_for_actual = {
                c: 'sum' for c in df.columns if pd.api.types.is_numeric_dtype(df[c])
            }
        if not op_for_actual:
            # No numeric column: there is nothing to total, and agg({}) is not usable.
            return empty_result

        actual = df.agg(op_for_actual).dropna()

        reps_values = self.extract_reps(df)
        if reps_values.empty:
            estimate = pd.Series(index=actual.index, dtype=float)
        else:
            total_weight = self.policy_count.sum()
            estimate_values = {}
            for col_name in actual.index:
                col_op = op_for_actual.get(col_name, 'sum')
                if col_name not in reps_values.columns:
                    estimate_values[col_name] = np.nan
                    continue

                # Each representative stands in for every policy of its cluster.
                weighted_sum = (reps_values[col_name] * self.policy_count).sum()
                if col_op == 'sum':
                    estimate_values[col_name] = weighted_sum
                elif col_op == 'mean':
                    estimate_values[col_name] = (
                        weighted_sum / total_weight if total_weight != 0 else np.nan
                    )
                else:
                    # Other aggregations have no defined proxy estimate.
                    estimate_values[col_name] = np.nan
            estimate = pd.Series(estimate_values, index=actual.index)

        actual_aligned, estimate_aligned = actual.align(estimate, join='inner')

        # `actual` holds no NaN after dropna(), so only the zero case needs care.
        error = pd.Series(np.nan, index=actual_aligned.index, dtype=float)
        nonzero = actual_aligned != 0
        error[nonzero] = estimate_aligned[nonzero] / actual_aligned[nonzero] - 1
        error[~nonzero & (estimate_aligned == 0)] = 0
        # actual == 0 with a non-zero estimate stays NaN (relative error undefined).
        error = error.replace([np.inf, -np.inf], np.nan)

        return pd.DataFrame(
            {'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error}
        )
|
|
|
|
|
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Draw actual-vs-estimate total curves, one panel per scenario, and return a PIL image.

    Args:
        cfs_list: Sequence of cashflow DataFrames (entries may be None or empty).
        cluster_obj: Object providing ``compare_total(df)``.
        titles: Panel titles, paired positionally with ``cfs_list``.

    Returns:
        The rendered figure as a PIL image. A placeholder image with an
        explanatory message is returned when no panel could be plotted.
    """

    def _snapshot(figure, **save_options):
        # Saves pyplot's current figure (always `figure` at the call sites) and closes it.
        stream = io.BytesIO()
        plt.savefig(stream, format='png', **save_options)
        stream.seek(0)
        picture = Image.open(stream)
        plt.close(figure)
        return picture

    if not (cfs_list and cluster_obj and titles) or len(cfs_list) == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data for cashflow comparison plot.", ha='center', va='center')
        return _snapshot(fig)

    # Two panels per row, as many rows as needed.
    n_cols = 2
    n_rows = (len(cfs_list) + n_cols - 1) // n_cols
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    panels = grid.flatten()

    any_panel_drawn = False
    last_used = -1
    for last_used, (panel, scenario_df, scenario_title) in enumerate(zip(panels, cfs_list, titles)):
        if scenario_df is None or scenario_df.empty:
            panel.text(0.5, 0.5, f"No data for {scenario_title}", ha='center', va='center')
            panel.set_title(scenario_title)
            continue

        totals = cluster_obj.compare_total(scenario_df)
        has_curves = not totals.empty and 'actual' in totals and 'estimate' in totals
        if not has_curves:
            panel.text(0.5, 0.5, f"Could not generate comparison for {scenario_title}", ha='center', va='center')
            panel.set_title(scenario_title)
            continue

        totals[['actual', 'estimate']].plot(ax=panel, grid=True, title=scenario_title)
        panel.set_xlabel('Time')
        panel.set_ylabel('Value')
        any_panel_drawn = True

    # Drop the grid cells that no scenario occupies.
    for spare in panels[last_used + 1:]:
        fig.delaxes(spare)

    if not any_panel_drawn:
        # Replace the empty grid by a single message figure.
        plt.close(fig)
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Insufficient data for any cashflow plots.", ha='center', va='center')

    plt.tight_layout()
    return _snapshot(fig, dpi=100)
|
|
|
def plot_scatter_comparison(df_compare_output, title):
    """Scatter estimate against actual values and return the chart as a PIL image.

    Args:
        df_compare_output: DataFrame with ``actual`` and ``estimate`` columns.
            When its index is a MultiIndex with at least two levels, points
            are coloured by the second level.
        title: Chart title.

    Returns:
        The rendered figure as a PIL image; a placeholder image when the input
        is None or empty.
    """

    def _snapshot(figure, **save_options):
        # Saves pyplot's current figure (always `figure` at the call sites) and closes it.
        stream = io.BytesIO()
        plt.savefig(stream, format='png', **save_options)
        stream.seek(0)
        picture = Image.open(stream)
        plt.close(figure)
        return picture

    if df_compare_output is None or df_compare_output.empty:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center')
        ax.set_title(title)
        return _snapshot(fig)

    fig, ax = plt.subplots(figsize=(10, 6))

    row_index = df_compare_output.index
    grouped = isinstance(row_index, pd.MultiIndex) and row_index.nlevels >= 2
    if grouped:
        groups = row_index.get_level_values(1).unique()
        palette = matplotlib.cm.rainbow(np.linspace(0, 1, len(groups)))
        for group, shade in zip(groups, palette):
            points = df_compare_output.xs(group, level=1)
            ax.scatter(points['actual'], points['estimate'], color=shade, s=9, alpha=0.6, label=str(group))
        # A legend is only shown when there are a few distinct groups.
        if 1 < len(groups) <= 10:
            ax.legend(title=row_index.names[1])
    else:
        ax.scatter(df_compare_output.get('actual', []), df_compare_output.get('estimate', []), s=9, alpha=0.6)

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.set_title(title)
    ax.grid(True)

    # Best effort: draw the y = x reference line over a square value range.
    try:
        x_range = ax.get_xlim()
        y_range = ax.get_ylim()
        low = np.nanmin([x_range, y_range])
        high = np.nanmax([x_range, y_range])
        lims = [low, high]
        if low != high and not np.isnan(low) and not np.isnan(high):
            ax.plot(lims, lims, 'r-', linewidth=0.5)
            ax.set_xlim(lims)
            ax.set_ylim(lims)
    except Exception:
        pass

    return _snapshot(fig, dpi=100)
|
|
|
|
|
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Run the cashflow, policy-attribute and present-value calibrations on seven workbooks.

    Each argument is the path of an Excel workbook whose first column is used
    as the index. Progress and problems are reported through ``gr.Info`` and
    ``gr.Warning`` toasts.

    Returns:
        On success, a dict of result tables (DataFrames) and charts (PIL
        images) keyed by ``cf_*``, ``attr_*``, ``pv_*`` and ``summary_plot``.
        On failure, a dict with the single key ``"error"`` holding a short
        message; no exception escapes.
    """
    results = {}
    try:
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        missing_policy_cols = [col for col in required_cols if col not in pol_data_full.columns]
        if missing_policy_cols:
            gr.Warning(f"Policy data is missing required columns: {', '.join(missing_policy_cols)}. Analysis may be affected.")
            pol_data = pol_data_full
        else:
            pol_data = pol_data_full[required_cols]

        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        pv_scenarios = [pvs, pvs_lapse50, pvs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']

        # Attributes are compared as averages, except the sum assured which is totalled.
        mean_attrs_agg = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}

        # --- Calibration on base cashflows ---
        gr.Info("Starting Cashflow Calibration...")
        if cfs.empty:
            gr.Warning("Base cashflow data is empty for Cashflow Calibration.")
        cluster_cfs = Clusters(cfs)
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
        gr.Info("Cashflow Calibration Done.")

        # --- Calibration on min-max scaled policy attributes ---
        gr.Info("Starting Policy Attribute Calibration...")
        # Stays None when the calibration is skipped; the summary chart checks for that.
        cluster_attrs = None
        if pol_data.empty:
            gr.Warning("Policy data is empty. Skipping Policy Attribute Calibration.")
            loc_vars_attrs = pd.DataFrame()
        else:
            pol_data_min = pol_data.min()
            pol_data_range = pol_data.max() - pol_data_min

            if (pol_data_range == 0).any():
                gr.Warning("Some policy attributes have no variance (all values are the same). Standardization might be affected.")

            # Avoid division by zero for constant columns; they scale to 0.
            pol_data_range[pol_data_range == 0] = 1
            loc_vars_attrs = (pol_data - pol_data_min) / pol_data_range
            loc_vars_attrs = loc_vars_attrs.fillna(0)

        if not loc_vars_attrs.empty:
            cluster_attrs = Clusters(loc_vars_attrs)
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
        else:
            results.update({k: pd.DataFrame() for k in ['attr_total_cf_base', 'attr_policy_attrs_total', 'attr_total_pv_base']})
            results.update({k: None for k in ['attr_cashflow_plot', 'attr_scatter_cashflows_base']})
        gr.Info("Policy Attribute Calibration Done.")

        # --- Calibration on base present values ---
        gr.Info("Starting Present Value Calibration...")
        if pvs.empty:
            gr.Warning("Base Present Value data is empty for PV Calibration.")
        cluster_pvs = Clusters(pvs)
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
        gr.Info("Present Value Calibration Done.")

        # --- Summary: absolute error of the total PV per calibration and scenario ---
        gr.Info("Generating Summary Plot...")
        error_data = {}
        pv_col_name = 'PV_NetCF'

        calibrations = [
            ("CF Calib.", cluster_cfs),
            ("Attr Calib.", cluster_attrs),
            ("PV Calib.", cluster_pvs),
        ]
        for calib_name_display, cluster_obj in calibrations:
            if cluster_obj is None:
                # Calibration was skipped: no error figures for any scenario.
                error_data[calib_name_display] = [np.nan] * len(pv_scenarios)
                continue

            current_calib_errors = []
            for pv_df_scenario in pv_scenarios:
                if pv_df_scenario.empty:
                    current_calib_errors.append(np.nan)
                    continue

                comp_total_df = cluster_obj.compare_total(pv_df_scenario)
                if pv_col_name in comp_total_df.index:
                    error_val = comp_total_df.loc[pv_col_name, 'error']
                elif not comp_total_df.empty and 'error' in comp_total_df.columns:
                    # Fallback when the expected column is absent: average over all columns.
                    error_val = comp_total_df['error'].mean()
                    # Warn once only (first calibration, base scenario).
                    if calib_name_display == 'CF Calib.' and pv_df_scenario is pvs:
                        gr.Warning(f"'{pv_col_name}' not found for summary plot. Using mean error of all PV columns instead for {calib_name_display}.")
                else:
                    error_val = np.nan
                current_calib_errors.append(abs(error_val))
            error_data[calib_name_display] = current_calib_errors

        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])

        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
        try:
            plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'
            if summary_df.isnull().all().all():
                ax_summary.text(0.5, 0.5, f"Error data for summary is N/A.\nCheck input PV files for '{pv_col_name}' column and valid numeric data.",
                                ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
                ax_summary.set_title(plot_title)
            elif summary_df.empty:
                ax_summary.text(0.5, 0.5, "No summary data to plot.", ha='center', va='center')
                ax_summary.set_title(plot_title)
            else:
                summary_df.plot(kind='bar', ax=ax_summary, grid=True)
                ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
                ax_summary.set_title(plot_title)
                ax_summary.tick_params(axis='x', rotation=0)

            # Work on the figure object rather than pyplot's "current figure".
            fig_summary.tight_layout()
            buf_summary = io.BytesIO()
            fig_summary.savefig(buf_summary, format='png', dpi=100)
            buf_summary.seek(0)
            results['summary_plot'] = Image.open(buf_summary)
        finally:
            # Release the figure even if plotting failed.
            plt.close(fig_summary)

        gr.Info("All processing complete.")
        return results

    # gr.Error only reaches the user when raised. To keep returning an error
    # dict, failures are reported with gr.Warning, which displays when called.
    except FileNotFoundError as e:
        gr.Warning(f"File not found: {e.filename}. Ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded correctly.")
        return {"error": f"File not found: {e.filename}"}
    except ValueError as e:
        gr.Warning(f"Data validation error: {str(e)}")
        return {"error": f"Data error: {str(e)}"}
    except KeyError as e:
        gr.Warning(f"A required column is missing: {e}. Please check data formats, especially index columns and expected data columns like 'PV_NetCF'.")
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        gr.Warning(f"An unexpected error occurred during processing: {str(e)}")
        traceback.print_exc()
        return {"error": f"Processing error: {str(e)}"}
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI: seven file inputs, two buttons and four result tabs.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
        # Cluster Model Points Analysis
        This application applies k-means cluster analysis to select representative model points from an insurance portfolio.
        Upload your Excel files or use the example data to analyze results based on different calibration variable choices.
        **Required Excel (.xlsx) Files:**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (must include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth', and an index column for `policy_id`)
        - Present Values - Base Scenario (ideally with a 'PV_NetCF' column and an index column for `policy_id`)
        - Present Values - Lapse Stress (same structure as Base PV)
        - Present Values - Mortality Stress (same structure as Base PV)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📂 Upload Files or Load Examples")
                # Button's `icon` argument expects an image path or URL, so the
                # emoji is part of the label instead.
                load_example_btn = gr.Button("💾 Load Example Data")
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                    # Hidden components completing the three-per-row pattern.
                    span_dummy = gr.File(visible=False)
                    span_dummy2 = gr.File(visible=False)

                analyze_btn = gr.Button("🚀 Analyze Dataset", variant="primary", scale=1)

        with gr.Tabs():
            with gr.TabItem("📊 Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison")

            with gr.TabItem("💸 Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows (Base) as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
                with gr.Accordion("Present Value Comparisons (Totals)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)

            with gr.TabItem("👤 Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
                with gr.Accordion("Present Value Comparisons (Totals)", open=False):
                    attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario", wrap=True)

            with gr.TabItem("💰 Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter: Per-Cluster PVs (Base)")
                with gr.Accordion("Present Value Comparisons (Totals)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)

        # Result key -> output component, kept in one place so the values
        # returned by the handler cannot drift out of order.
        result_outputs = [
            ('summary_plot', summary_plot_output),
            ('cf_total_base_table', cf_total_base_table_out),
            ('cf_policy_attrs_total', cf_policy_attrs_total_out),
            ('cf_cashflow_plot', cf_cashflow_plot_out),
            ('cf_scatter_cashflows_base', cf_scatter_cashflows_base_out),
            ('cf_pv_total_base', cf_pv_total_base_out),
            ('cf_pv_total_lapse', cf_pv_total_lapse_out),
            ('cf_pv_total_mort', cf_pv_total_mort_out),
            ('attr_total_cf_base', attr_total_cf_base_out),
            ('attr_policy_attrs_total', attr_policy_attrs_total_out),
            ('attr_cashflow_plot', attr_cashflow_plot_out),
            ('attr_scatter_cashflows_base', attr_scatter_cashflows_base_out),
            ('attr_total_pv_base', attr_total_pv_base_out),
            ('pv_total_cf_base', pv_total_cf_base_out),
            ('pv_policy_attrs_total', pv_policy_attrs_total_out),
            ('pv_cashflow_plot', pv_cashflow_plot_out),
            ('pv_scatter_pvs_base', pv_scatter_pvs_base_out),
            ('pv_total_pv_base', pv_total_pv_base_out),
            ('pv_total_pv_lapse', pv_total_pv_lapse_out),
            ('pv_total_pv_mort', pv_total_pv_mort_out),
        ]
        output_components = [component for _, component in result_outputs]

        def _displayable(value):
            # The Dataframe component renders a frame's columns only, so the row
            # labels (the compared item names) are moved into a regular column.
            if isinstance(value, pd.DataFrame):
                return value.rename_axis(value.index.name or "item").reset_index()
            return value

        def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7):
            """Validate the seven inputs, run the analysis and map results to the outputs."""
            uploads = [f1, f2, f3, f4, f5, f6, f7]
            blank_outputs = [None] * len(output_components)

            if any(f is None for f in uploads):
                gr.Warning("Not all files have been provided. Please upload all 7 files or load example data.")
                return blank_outputs

            # Inputs arrive either as objects exposing a `.name` path or as plain path strings.
            file_paths = []
            for f_obj in uploads:
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    # gr.Error is an exception: it is only shown when raised.
                    raise gr.Error(f"Invalid file input: {f_obj}. Please re-upload or reload examples.")

            analysis_results = process_files(*file_paths)
            if "error" in analysis_results:
                return blank_outputs

            return [_displayable(analysis_results.get(key)) for key, _ in result_outputs]

        analyze_btn.click(
            handle_analysis_click,
            inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                    policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
            outputs=output_components
        )

        input_file_components = [
            cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
            policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
        ]

        def load_example_files_action():
            """Fill the seven file inputs with the bundled example workbook paths."""
            missing_example_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_example_files:
                # Raised (not merely constructed) so that the message is displayed.
                raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_example_files)}. Please ensure they exist.")
            gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"]
            ]

        load_example_btn.click(load_example_files_action, inputs=[], outputs=input_file_components)

    return demo
|
|
|
# Script entry point: make sure the example-data folder exists, print setup
# hints, then build and launch the Gradio app.
if __name__ == "__main__":
    if not os.path.exists(EXAMPLE_DATA_DIR):
        try:
            os.makedirs(EXAMPLE_DATA_DIR)
            print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
            # List the workbook file names the user has to provide, not the internal dict keys.
            print(f"Expected files: {[os.path.basename(path) for path in EXAMPLE_FILES.values()]}")
        except OSError as e:
            print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")

    print("Starting Gradio application...")
    print(f"Note: Ensure your example Excel files are placed in the '{os.getcwd()}{os.sep}{EXAMPLE_DATA_DIR}' folder.")
    print("Required policy data columns: 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' (and an index col).")
    print("Recommended PV files column for summary: 'PV_NetCF' (and an index col).")

    demo_app = create_interface()
    demo_app.launch()