File size: 4,341 Bytes
25a3b59
 
 
 
 
 
 
 
 
62b932f
cfeae65
 
 
 
 
 
97fe953
 
 
 
 
62b932f
cfeae65
 
62b932f
25a3b59
 
 
 
 
 
 
 
 
62b932f
 
 
 
 
25a3b59
 
8215350
25a3b59
62b932f
8215350
 
 
 
25a3b59
 
 
 
 
 
 
 
 
 
 
 
62b932f
25a3b59
62b932f
 
 
 
25a3b59
62b932f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d320b
25a3b59
62b932f
cfeae65
 
 
25a3b59
8215350
82d320b
8215350
25a3b59
cfeae65
8215350
82d320b
8215350
25a3b59
82d320b
 
25a3b59
82d320b
 
 
 
 
 
 
 
 
62b932f
82d320b
 
 
 
 
 
 
62b932f
82d320b
 
 
 
 
25a3b59
62b932f
82d320b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import requests
import pandas as pd
import plotly.graph_objects as go
import gradio as gr

# FRED API key, read from the FRED_API_KEY environment variable (None when unset).
FRED_API_KEY = os.environ.get('FRED_API_KEY')

# Selectable FRED series: series ID -> human-readable label.
# The labels become the checkbox choices in the UI, in this insertion order.
# NOTE(review): FRED lists M3SL as discontinued (M3 publication stopped in
# 2006) — confirm it is still wanted here.
series_options = dict(
    UNRATE="Unemployment Rate",
    GDP="Gross Domestic Product",
    CPIAUCSL="Consumer Price Index for All Urban Consumers",
    DGS10="10-Year Treasury Constant Maturity Rate",
    FEDFUNDS="Effective Federal Funds Rate",
    M1SL="M1 Money Supply",
    M2SL="M2 Money Supply",
    M3SL="M3 Money Supply",
    HOUST="Housing Starts",
    PCE="Personal Consumption Expenditures",
    BAA10YM="Moody's Baa Corporate Bond Yield Spread",
)

# Function to fetch data from the FRED API
def fetch_fred_data(series_ids):
    """
    Fetches data for a list of FRED series IDs.
    Returns a DataFrame with columns as series and dates as rows.
    """
    data = {}
    for series_id in series_ids:
        response = requests.get(
            f'https://api.stlouisfed.org/fred/series/observations',
            params={
                'series_id': series_id,
                'api_key': FRED_API_KEY,
                'file_type': 'json'
            }
        )
        if response.status_code == 200:
            observations = response.json().get('observations', [])
            dates = [obs['date'] for obs in observations]
            # Convert values to float, handling invalid entries
            values = [
                float(obs['value']) if obs['value'].replace('.', '', 1).isdigit() else float('nan')
                for obs in observations
            ]
            data[series_id] = pd.Series(values, index=dates)
        else:
            print(f"Failed to fetch data for {series_id}")
    return pd.DataFrame(data)

# Function to standardize data (z-scores)
def standardize_data(df):
    """
    Convert every column of ``df`` to z-scores.

    Each column has its own mean subtracted and is then divided by its own
    standard deviation as computed by ``DataFrame.std()`` (pandas default,
    sample std), so each resulting column has mean 0 and standard deviation 1.
    """
    column_means = df.mean()
    column_stds = df.std()
    centered = df.sub(column_means)
    return centered.div(column_stds)

# Function to create a responsive 3D correlation matrix
def create_3d_correlation_matrix(df):
    """
    Build a Plotly 3D surface of the pairwise column correlations of ``df``.

    The matrix comes from ``DataFrame.corr()``. Its column labels are used as
    the x coordinates, its index labels as the y coordinates and the
    coefficients as the surface height. ``autosize`` is enabled so the figure
    resizes with its container.

    Returns:
        plotly.graph_objects.Figure
    """
    corr = df.corr()
    surface = go.Surface(
        x=corr.columns,
        y=corr.index,
        z=corr.values,
    )
    scene_axes = dict(
        xaxis=dict(title='Variables'),
        yaxis=dict(title='Variables'),
        zaxis=dict(title='Correlation'),
    )

    figure = go.Figure(data=[surface])
    figure.update_layout(
        title='3D Correlation Matrix (Standardized)',
        autosize=True,
        scene=scene_axes,
        # No side margins; leave room above and below for the title.
        margin=dict(l=0, r=0, t=50, b=50),
    )
    return figure

# Gradio function to handle user interaction
def visualize_correlation(selected_series):
    """
    Gradio click handler: fetch the chosen FRED series and plot their correlations.

    Args:
        selected_series: the descriptive labels ticked in the CheckboxGroup
            (values of ``series_options``); None or empty means nothing chosen.

    Returns:
        (figure, message): the Plotly figure (None when nothing can be plotted)
        and a Markdown status string (None when there is nothing to report).
    """
    chosen_labels = set(selected_series or [])
    # Map descriptive labels back to FRED series IDs, in series_options order.
    series_ids = [sid for sid, label in series_options.items() if label in chosen_labels]

    # A single variable's correlation matrix is just [[1.0]]: meaningless, and
    # a 1x1 grid gives the surface nothing to draw.
    if len(series_ids) < 2:
        return None, "Please select at least two indicators."

    # Fetch and process data
    df = fetch_fred_data(series_ids)
    if df.empty:
        return None, "Failed to fetch data for the selected indicators."

    # fetch_fred_data silently skips series it could not download; report them.
    missing_labels = [series_options[sid] for sid in series_ids if sid not in df.columns]
    warning = "Could not fetch: " + ", ".join(missing_labels) if missing_labels else None

    if len(df.columns) < 2:
        return None, (
            "Need at least two indicators with data to build a correlation matrix. "
            + (warning or "")
        ).strip()

    standardized_df = standardize_data(df)
    plot = create_3d_correlation_matrix(standardized_df)
    return plot, warning

# Gradio Blocks Interface
with gr.Blocks() as demo:
    gr.Markdown("# 3D Correlation Matrix Visualization with FRED Data")

    with gr.Row():
        with gr.Column():
            series_selector = gr.CheckboxGroup(
                choices=list(series_options.values()),
                label="Select Economic Indicators",
                info="Choose one or more indicators to include in the correlation matrix."
            )
            submit_button = gr.Button("Generate Matrix")

        with gr.Column():
            plot_output = gr.Plot(label="3D Correlation Matrix")
            # Status text returned by visualize_correlation. It must be created
            # visible: the handler returns plain strings, which update only the
            # value, so a visible=False component would never show a message.
            error_message = gr.Markdown("")

    # Event handler for the submit button
    submit_button.click(
        fn=visualize_correlation,
        inputs=[series_selector],
        outputs=[plot_output, error_message],
    )

# Launch the Gradio app only when run as a script, so importing this module
# does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)