File size: 8,183 Bytes
d80731b
 
97d673b
d80731b
 
25a3b59
d80731b
 
 
 
 
 
 
 
 
 
 
 
07bdccd
d80731b
 
07bdccd
 
 
 
 
 
 
 
 
cfeae65
 
d80731b
 
 
 
 
 
 
 
 
 
 
 
 
07bdccd
 
d80731b
 
 
 
07bdccd
 
 
 
 
 
 
 
 
 
d80731b
07bdccd
 
 
 
 
 
 
 
d80731b
 
 
 
 
07bdccd
d80731b
07bdccd
 
 
 
 
 
 
d80731b
 
 
 
 
 
 
07bdccd
 
 
 
 
 
 
 
d80731b
 
 
07bdccd
 
d80731b
07bdccd
62b932f
07bdccd
d80731b
62b932f
d80731b
 
07bdccd
62b932f
07bdccd
 
 
62b932f
07bdccd
25a3b59
d80731b
 
07bdccd
 
 
 
 
 
d80731b
 
97d673b
d80731b
 
07bdccd
 
 
97d673b
07bdccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97d673b
d80731b
 
 
07bdccd
 
 
 
 
 
82d320b
d80731b
 
 
 
07bdccd
 
97d673b
d80731b
07bdccd
d80731b
609c79d
d80731b
97d673b
d80731b
 
 
 
07bdccd
d80731b
97d673b
d80731b
07bdccd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import requests
import pandas as pd
import plotly.graph_objects as go
import gradio as gr

# FRED API key, read from the environment; None when the variable is unset.
FRED_API_KEY = os.getenv('FRED_API_KEY')

# FRED series IDs mapped to the descriptive labels shown in the UI.
# Labels must be unique: the UI lists the labels, and visualize_correlation
# maps each selected label back to the series ID(s) that carry it.
# M3SL (M3 Money Supply) is omitted: M3 data is no longer published by FRED regularly.
series_options = {
    "UNRATE": "Unemployment Rate",
    "GDP": "Gross Domestic Product",
    "CPIAUCSL": "Consumer Price Index for All Urban Consumers",
    "DGS10": "10-Year Treasury Constant Maturity Rate",
    "FEDFUNDS": "Effective Federal Funds Rate",
    "M1SL": "M1 Money Supply",
    "M2SL": "M2 Money Supply",
    "HOUST": "Housing Starts",
    "PCE": "Personal Consumption Expenditures",
    "BAA10YM": "Moody's Baa Corporate Bond Yield Spread",
    "T10Y2Y": "10-Year Minus 2-Year Treasury Yield Spread",  # Common recession indicator
    # NOTE(review): confirm IPGDIC is a valid FRED series ID; failed fetches are only printed.
    "IPGDIC": "Industrial Production Index (Goods and Structures)",
    "MRTSSM44000USN": "Retail and Food Services Sales",  # Monthly Retail Trade
    "CSUSHPINSA": "S&P/Case-Shiller U.S. National Home Price Index",
    "UMCSENT": "University of Michigan: Consumer Sentiment Index",
    "PPIACO": "Producer Price Index: All Commodities",
    # Distinct label from FEDFUNDS: a duplicate label would show two identical
    # checkboxes and make either one fetch both series.
    "DFF": "Effective Federal Funds Rate (Daily)",
    "WALCL": "Assets: Total Assets: Total Assets (Less Eliminations from Consolidation): Wednesday Level"  # Federal Reserve Balance Sheet
}

# Function to fetch data from the FRED API
def fetch_fred_data(series_ids, observation_start='1970-01-01', timeout=30):
    """
    Fetch observations for a list of FRED series IDs.

    Each series is requested from the FRED ``series/observations`` endpoint.
    A series that fails to download, returns invalid JSON or has no
    observations is reported with ``print`` and skipped. One bad series
    therefore does not abort the others.

    Parameters:
        series_ids: iterable of FRED series IDs (e.g. "UNRATE").
        observation_start: earliest observation date to request (YYYY-MM-DD).
            A long history gives the correlation more data points.
        timeout: per-request timeout in seconds, passed to ``requests.get``.

    Returns:
        DataFrame with one float column per successfully fetched series and a
        sorted DatetimeIndex. Rows where every series is NaN are dropped.
    """
    data = {}
    for series_id in series_ids:
        try:
            response = requests.get(
                'https://api.stlouisfed.org/fred/series/observations',
                params={
                    'series_id': series_id,
                    'api_key': FRED_API_KEY,
                    'file_type': 'json',
                    'observation_start': observation_start,
                },
                timeout=timeout,  # without a timeout a stalled connection blocks forever
            )
        except requests.RequestException as exc:
            print(f"Failed to fetch data for {series_id}: {exc}")
            continue

        if response.status_code != 200:
            print(f"Failed to fetch data for {series_id}. Status code: {response.status_code}")
            print(f"Response: {response.text}")
            continue

        try:
            observations = response.json().get('observations', [])
        except ValueError:
            print(f"Invalid JSON returned for {series_id}")
            continue

        if not observations:
            print(f"No observations found for {series_id}")
            continue

        dates = [obs['date'] for obs in observations]
        # FRED marks missing values with "."; errors='coerce' maps that (and any
        # other non-numeric text) to NaN. Unlike a digit-only check, it keeps
        # negative values such as an inverted T10Y2Y spread.
        values = pd.to_numeric([obs['value'] for obs in observations], errors='coerce')
        data[series_id] = pd.Series(values, index=dates, dtype=float)

    df = pd.DataFrame(data)
    df.index = pd.to_datetime(df.index)
    df = df.sort_index()
    return df.dropna(how='all')  # drop dates where no series has a value

# Function to standardize data (z-scores)
def standardize_data(df):
    """
    Return a z-scored copy of ``df``, computed column by column.

    A column whose standard deviation is exactly 0 (constant values) is
    replaced by 0.0 everywhere. A column with no non-null values becomes
    all NaN. Every other column becomes (value - mean) / std, using pandas'
    default sample standard deviation. The row index of ``df`` is preserved.
    """
    result = pd.DataFrame(index=df.index)
    for name, column in df.items():
        spread = column.std()
        if spread == 0:
            result[name] = 0.0
        elif column.isnull().all():
            result[name] = float('nan')
        else:
            result[name] = (column - column.mean()) / spread
    return result

# Function to create a responsive 3D correlation matrix
def create_3d_correlation_matrix(df):
    """
    Build a Plotly 3D surface of the pairwise correlation matrix of ``df``.

    Parameters:
        df: DataFrame whose numeric columns are the indicators to correlate.

    Returns:
        ``(figure, None)`` on success. Returns ``(None, message)`` when no
        meaningful matrix can be drawn: fewer than two columns, no numeric
        data, or every pairwise correlation undefined (e.g. constant columns
        or series with no overlapping dates).
    """
    if df.shape[1] < 2:
        return None, "Please select at least two indicators to compute a correlation matrix."

    # numeric_only restricts the computation to numeric columns.
    correlation_matrix = df.corr(numeric_only=True)

    # An all-NaN matrix would otherwise render as an empty plot with no explanation.
    if correlation_matrix.empty or correlation_matrix.isna().all().all():
        return None, "Not enough valid data points to compute correlation."

    fig = go.Figure(data=[go.Surface(
        z=correlation_matrix.values,
        x=list(correlation_matrix.columns),
        y=list(correlation_matrix.index),
        colorscale='Viridis',
    )])

    fig.update_layout(
        title='3D Correlation Matrix (Standardized Indicators)',
        autosize=True,
        scene=dict(
            xaxis=dict(title='Variables'),
            yaxis=dict(title='Variables'),
            zaxis=dict(title='Correlation', range=[-1, 1]),  # correlations lie in [-1, 1]
        ),
        margin=dict(l=0, r=0, t=50, b=50),
        hovermode="closest",
        height=600,  # initial height; autosize handles the width
    )
    return fig, None

# Gradio function to handle user interaction
def visualize_correlation(selected_series):
    """
    Gradio callback: fetch, standardize and plot the selected indicators' correlation.

    Parameters:
        selected_series: descriptive labels (values of ``series_options``)
            ticked in the CheckboxGroup. May be None or empty.

    Returns:
        ``(figure, None)`` on success, or ``(None, message)`` explaining why
        no plot could be produced.

    NOTE: pandas' ``corr`` uses pairwise-complete observations, so series of
    different frequencies are correlated only on the dates they share.
    """
    if not selected_series or len(selected_series) < 2:
        return None, "Please select at least two economic indicators to visualize their correlation."

    # Without a key every FRED request fails. That would otherwise surface only
    # as the generic "not enough valid data" message below.
    if not FRED_API_KEY:
        return None, "The FRED_API_KEY environment variable is not set, so no data can be fetched from FRED."

    # Map the descriptive UI labels back to FRED series IDs, in series_options order.
    wanted = set(selected_series)
    series_ids = [series_id for series_id, label in series_options.items() if label in wanted]

    df = fetch_fred_data(series_ids)

    # Discard series that came back without a single valid observation.
    df_cleaned = df.dropna(axis=1, how='all')
    if df_cleaned.empty or df_cleaned.shape[1] < 2:
        return None, "Not enough valid data available for the selected indicators to compute a meaningful correlation. Please try selecting different indicators or ensure data availability."

    # Keep only dates where at least one remaining series has a value.
    df_cleaned = df_cleaned.dropna(how='all')

    standardized_df = standardize_data(df_cleaned)

    # Drop columns that standardized to all NaN, e.g. a series with a single
    # observation, whose sample standard deviation is undefined. Constant
    # series are set to 0.0 by standardize_data rather than NaN, so they stay.
    standardized_df = standardized_df.dropna(axis=1, how='all')
    if standardized_df.empty or standardized_df.shape[1] < 2:
        return None, "After standardization, not enough varying data for the selected indicators to compute a meaningful correlation. Please try different indicators."

    return create_3d_correlation_matrix(standardized_df)

# Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# 3D Correlation Matrix Visualization with FRED Data")
    gr.Markdown(
        "Explore the correlations between various economic indicators from the Federal Reserve Economic Data (FRED) database. "
        "Select at least two indicators to generate a 3D correlation matrix. "
        "The data is standardized (z-scored) before correlation calculation."
    )

    with gr.Row():
        with gr.Column():
            series_selector = gr.CheckboxGroup(
                choices=list(series_options.values()),
                label="Select Economic Indicators",
                info="Choose two or more indicators to include in the correlation matrix. "
                     "Hover over the 3D plot to see correlation values."
            )
            submit_button = gr.Button("Generate Matrix")

        with gr.Column():
            plot_output = gr.Plot(label="3D Correlation Matrix")
            # Visible from the start (an empty Markdown shows nothing). The
            # callback only sets this component's value, never its visibility,
            # so a hidden component would never display the error messages.
            error_message = gr.Markdown("")

    # visualize_correlation returns (figure or None, message or None).
    submit_button.click(
        fn=visualize_correlation,
        inputs=[series_selector],
        outputs=[plot_output, error_message]
    )

# Launch only when run as a script, so importing this module does not start a
# server. Pass share=True to launch() to get a public link.
if __name__ == "__main__":
    demo.launch(debug=True)