# Author: dibend — commit 62b932f ("Update app.py", verified)
import os
import requests
import pandas as pd
import plotly.graph_objects as go
import gradio as gr
# FRED API key taken from the process environment; None when it is not set.
FRED_API_KEY = os.environ.get('FRED_API_KEY')

# FRED series ID -> human-readable label shown in the UI.
# Insertion order is preserved and determines the order of the checkboxes.
series_options = dict(
    UNRATE="Unemployment Rate",
    GDP="Gross Domestic Product",
    CPIAUCSL="Consumer Price Index for All Urban Consumers",
    DGS10="10-Year Treasury Constant Maturity Rate",
    FEDFUNDS="Effective Federal Funds Rate",
    M1SL="M1 Money Supply",
    M2SL="M2 Money Supply",
    M3SL="M3 Money Supply",
    HOUST="Housing Starts",
    PCE="Personal Consumption Expenditures",
    BAA10YM="Moody's Baa Corporate Bond Yield Spread",
)
# Helper: turn raw FRED observation records into a numeric, date-indexed Series
def _observations_to_series(observations):
    """
    Convert a FRED ``observations`` list into a float Series indexed by date.

    Each observation is a dict with string ``'date'`` and ``'value'`` fields.
    Values are parsed with ``pd.to_numeric(errors='coerce')``, so FRED's
    missing-value marker ``"."`` (and anything else non-numeric) becomes NaN,
    while negative numbers such as ``"-0.25"`` are kept rather than dropped.
    """
    dates = pd.to_datetime([obs['date'] for obs in observations])
    values = pd.to_numeric([obs['value'] for obs in observations], errors='coerce')
    return pd.Series(values, index=dates, dtype=float)


# Function to fetch data from the FRED API
def fetch_fred_data(series_ids, timeout=30):
    """
    Fetch observations for several FRED series.

    Parameters
    ----------
    series_ids : iterable of str
        FRED series IDs, e.g. ``"UNRATE"``.
    timeout : float, optional
        Seconds to wait for each HTTP request (default 30). Without a timeout
        a stalled connection would block the caller indefinitely.

    Returns
    -------
    pandas.DataFrame
        One float column per successfully fetched series, with rows aligned on
        the union of observation dates. Series that could not be fetched
        (network error, non-200 status, or a non-JSON body) are reported on
        stdout and omitted, so the frame is empty if every request failed.
    """
    data = {}
    for series_id in series_ids:
        try:
            response = requests.get(
                'https://api.stlouisfed.org/fred/series/observations',
                params={
                    'series_id': series_id,
                    'api_key': FRED_API_KEY,
                    'file_type': 'json',
                },
                timeout=timeout,
            )
        except requests.RequestException as exc:
            # Print only the exception type: its message can include the
            # request URL, whose query string contains the API key.
            print(f"Failed to fetch data for {series_id} ({type(exc).__name__})")
            continue
        if response.status_code != 200:
            print(f"Failed to fetch data for {series_id}")
            continue
        try:
            observations = response.json().get('observations', [])
        except ValueError:  # body was not valid JSON
            print(f"Failed to fetch data for {series_id} (invalid JSON response)")
            continue
        data[series_id] = _observations_to_series(observations)
    return pd.DataFrame(data)
# Function to standardize data (z-scores)
def standardize_data(df):
    """
    Return a copy of ``df`` with every column converted to z-scores.

    Each column is centred on its own mean and divided by its own sample
    standard deviation (pandas' default ``ddof=1``). NaN entries are skipped
    when computing the statistics and remain NaN in the result.
    """
    column_means = df.mean()
    column_stds = df.std()
    centred = df - column_means
    return centred / column_stds
# Function to create a responsive 3D correlation matrix
def create_3d_correlation_matrix(df):
    """
    Build a 3D surface plot of the pairwise correlation matrix of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per variable. Correlations come from ``DataFrame.corr()``
        (Pearson, computed pairwise while excluding NA values).

    Returns
    -------
    plotly.graph_objects.Figure
        A surface whose x/y axes are the column labels and whose height is the
        correlation coefficient. ``autosize`` lets it fill its container.
    """
    correlation_matrix = df.corr()
    surface = go.Surface(
        z=correlation_matrix.values,
        x=correlation_matrix.columns,
        y=correlation_matrix.index,
        # Correlations always lie in [-1, 1]. Pinning the colour range keeps
        # colours meaningful and comparable across selections instead of
        # stretching the palette over whatever narrow band this matrix spans.
        cmin=-1,
        cmax=1,
    )
    fig = go.Figure(data=[surface])
    fig.update_layout(
        title='3D Correlation Matrix (Standardized)',
        autosize=True,  # Enables auto-resizing
        scene=dict(
            xaxis=dict(title='Variables'),
            yaxis=dict(title='Variables'),
            # Same reasoning as cmin/cmax: show the full correlation scale.
            zaxis=dict(title='Correlation', range=[-1, 1]),
        ),
        margin=dict(l=0, r=0, t=50, b=50),  # Adjust margins for better fit
    )
    return fig
# Gradio function to handle user interaction
def visualize_correlation(selected_series):
    """
    Gradio callback: fetch the chosen FRED series and plot their correlations.

    Parameters
    ----------
    selected_series : list of str or None
        Descriptive labels (values of ``series_options``) ticked in the UI.

    Returns
    -------
    tuple
        ``(figure, None)`` on success, or ``(None, message)`` explaining why no
        plot could be produced.
    """
    selected_labels = set(selected_series or [])
    # Map descriptive labels back to FRED series IDs, in series_options order.
    series_ids = [
        series_id
        for series_id, label in series_options.items()
        if label in selected_labels
    ]
    # A correlation matrix needs at least two variables: a single one yields a
    # trivial 1x1 matrix that is too small to render as a surface.
    if len(series_ids) < 2:
        return None, "Please select at least two indicators."
    # Fetch and process data; drop series that came back with no usable values.
    df = fetch_fred_data(series_ids).dropna(axis=1, how='all')
    if df.empty:
        return None, "Failed to fetch data for the selected indicators."
    if df.shape[1] < 2:
        return None, "Could not fetch data for at least two of the selected indicators."
    standardized_df = standardize_data(df)
    plot = create_3d_correlation_matrix(standardized_df)
    return plot, None
# Gradio Blocks Interface
with gr.Blocks() as demo:
    gr.Markdown("# 3D Correlation Matrix Visualization with FRED Data")
    with gr.Row():
        with gr.Column():
            series_selector = gr.CheckboxGroup(
                choices=list(series_options.values()),
                label="Select Economic Indicators",
                info="Choose two or more indicators to include in the correlation matrix."
            )
            submit_button = gr.Button("Generate Matrix")
        with gr.Column():
            plot_output = gr.Plot(label="3D Correlation Matrix")
            # Must stay visible: the callback only returns a new value for this
            # component, which does not change its visibility, so with
            # visible=False error messages were never shown. An empty or None
            # value renders nothing, so no text appears on success.
            error_message = gr.Markdown("")
    # Event handler for the submit button
    submit_button.click(
        fn=visualize_correlation,
        inputs=[series_selector],
        outputs=[plot_output, error_message],
    )
# Launch the Gradio app
demo.launch(debug=True)