import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from matplotlib.ticker import PercentFormatter


def plot_wow_retention_by_type(wow_retention: pd.DataFrame) -> gr.Plot:
    """Build a line chart of weekly retention rate, one line per trader type.

    Args:
        wow_retention: Frame with at least the columns ``week`` (anything
            ``pd.to_datetime`` can parse), ``trader_type`` and
            ``retention_rate``. The rate is rendered with a ``%`` suffix, so
            it is presumably already expressed in percent (0-100) -- confirm
            against the caller.

    Returns:
        A ``gr.Plot`` component wrapping the Plotly figure.
    """
    # Work on a copy so the caller's DataFrame is not mutated by the
    # datetime conversion below.
    wow_retention = wow_retention.copy()
    wow_retention["week"] = pd.to_datetime(wow_retention["week"])
    wow_retention = wow_retention.sort_values(["trader_type", "week"])

    # Series.max() skips NaN and returns NaN for an empty frame, whereas the
    # builtin max() raises on empty input and is order-dependent with NaN.
    peak_rate = wow_retention["retention_rate"].max()
    if pd.notna(peak_rate) and peak_rate > 0:
        y_range = [0, peak_rate * 1.1]  # 10% headroom above the highest point
    else:
        y_range = None  # nothing sensible to scale to: let Plotly autorange

    fig = px.line(
        wow_retention,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
        color_discrete_sequence=["purple", "goldenrod", "green"],
    )

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=0.99,
            orientation="v",
        ),
        yaxis=dict(ticksuffix="%", range=y_range),
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=200),  # right margin leaves room for the legend
        width=600,
        height=500,
    )

    # Hover text: rate on the first line, week on the second. Plotly hover
    # templates use <br> for line breaks.
    fig.update_traces(hovertemplate="%{y:.1f}%<br>Week: %{x|%Y-%m-%d}")

    return gr.Plot(value=fig)


def plot_cohort_retention_heatmap(retention_matrix: pd.DataFrame, cmap: str):
    """Render a cohort retention matrix as an annotated seaborn heatmap.

    Args:
        retention_matrix: Matrix whose index holds cohort dates (parsed with
            ``pd.to_datetime``) and whose columns are week offsets. Cell
            values are mapped onto a fixed 0-100 colour scale; NaN cells are
            masked out.
        cmap: Name of the matplotlib/seaborn colormap to use.

    Returns:
        A ``gr.Plot`` component wrapping the matplotlib figure.
    """
    # Copy so that relabelling the index does not touch the caller's frame.
    retention_matrix = retention_matrix.copy()

    # Show cohort dates as e.g. "Mon-Jan 01" on the y-axis.
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime(
        "%a-%b %d"
    )

    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current figure" so the figure can be released deterministically.
    fig, ax = plt.subplots(figsize=(12, 8))

    # Supplying the labels to seaborn keeps label count and tick count in
    # sync even for wide matrices where automatic tick thinning would
    # otherwise make a later set_xticklabels() call fail.
    x_labels = [f"Week {i}" for i in retention_matrix.columns]

    sns.heatmap(
        data=retention_matrix,
        annot=True,  # print the value inside each cell
        fmt=".1f",  # one decimal place
        cmap=cmap,
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=retention_matrix.isna(),  # leave missing cells blank
        annot_kws={"size": 8},
        xticklabels=x_labels,
        ax=ax,
    )

    ax.set_title("Cohort Retention Analysis", pad=20, size=14)
    ax.set_xlabel("Weeks Since First Activity", size=12)
    ax.set_ylabel("Cohort First Day of the Week", size=12)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), rotation=0)

    # Draw any grid/tick lines beneath the heatmap cells.
    ax.set_axisbelow(True)

    # Adjust layout to prevent label cutoff.
    fig.tight_layout()

    # Deregister the figure from pyplot so repeated calls (e.g. from a
    # long-running Gradio app) do not accumulate open figures. The Figure
    # object itself stays valid and can still be rendered.
    plt.close(fig)

    return gr.Plot(value=fig)