causalscience committed
Commit 46b4d92 · verified · 1 Parent(s): ee87011

Aug 25 Bug Fixes

Files changed (1)
  1. models/granger.py +300 -139
models/granger.py CHANGED
@@ -1,139 +1,300 @@
-# causalscience/models/granger.py
-
-import numpy as np
-import pandas as pd
-import statsmodels.api as sm
-from statsmodels.tsa.stattools import grangercausalitytests, adfuller
-import matplotlib.pyplot as plt
-from io import BytesIO
-from PIL import Image
-
-
-def adf_stationarity_test(series, alpha=0.05):
-    """
-    Perform Augmented Dickey-Fuller test to check stationarity.
-
-    Args:
-        series (pd.Series): Time series data.
-        alpha (float): Significance level.
-
-    Returns:
-        p_value (float): p-value of the test.
-        is_stationary (bool): True if series is stationary.
-    """
-    result = adfuller(series.dropna(), autolag='AIC')
-    p_value = result[1]
-    return p_value, p_value < alpha
-
-
-def difference_series(series, order=1):
-    """
-    Difference the series to make it stationary.
-    """
-    return series.diff(periods=order).dropna()
-
-
-def make_data_stationary(df, columns, max_diff=2, alpha=0.05):
-    """
-    Iteratively difference columns to achieve stationarity.
-
-    Args:
-        df (pd.DataFrame): Input DataFrame.
-        columns (list[str]): Columns to transform.
-        max_diff (int): Maximum differencing order.
-        alpha (float): Significance level.
-
-    Returns:
-        df_out (pd.DataFrame): Transformed DataFrame.
-        transformations (dict): Info on differencing orders.
-    """
-    df_out = df.copy()
-    transformations = {}
-    for col in columns:
-        pval, stationary = adf_stationarity_test(df_out[col], alpha)
-        if stationary:
-            transformations[col] = f"Already stationary (p={pval:.4f})"
-            continue
-        for d in range(1, max_diff + 1):
-            diff_series = difference_series(df_out[col], order=d)
-            pval_d, stat_d = adf_stationarity_test(diff_series, alpha)
-            if stat_d:
-                df_out[col] = diff_series
-                transformations[col] = f"Differenced order {d} (p={pval_d:.4f})"
-                break
-        else:
-            transformations[col] = f"No stationarity by {max_diff} diffs"
-    df_out = df_out.dropna()
-    return df_out, transformations
-
-
-def recommend_var_lag(df, maxlags=7, criterion='aic'):
-    """
-    Recommend lag order for VAR model by information criterion.
-    """
-    model = sm.tsa.VAR(df.dropna())
-    results = model.select_order(maxlags=maxlags)
-    return results.selected_orders.get(criterion)
-
-
-def run_granger_analysis(df, max_lags=7, criterion='aic', apply_transformation=False,
-                         columns_to_transform=None, max_diff=2, alpha=0.05):
-    """
-    Run Granger causality analysis between first two columns.
-
-    Args:
-        df (pd.DataFrame): Time series DataFrame.
-        max_lags (int): Max lags to test.
-        criterion (str): Criterion for lag selection.
-        apply_transformation (bool): Whether to difference to stationarity.
-        columns_to_transform (list[str]): Columns to transform.
-        max_diff (int): Max differencing.
-        alpha (float): Significance level.
-
-    Returns:
-        summary_text (str): Text output of tests.
-        plot_img (PIL.Image.Image): Time series plot.
-        transformed_csv (str or None): Path to stationary data CSV.
-    """
-    transformed_csv = None
-    df_work = df.copy()
-    if apply_transformation and columns_to_transform:
-        df_work, trans = make_data_stationary(df_work, columns_to_transform, max_diff, alpha)
-        transformed_csv = 'transformed_data.csv'
-        df_work.to_csv(transformed_csv, index=False)
-    if df_work.shape[1] < 2:
-        raise ValueError("Need at least two series for Granger causality")
-
-    lag_order = recommend_var_lag(df_work, maxlags=max_lags, criterion=criterion) or min(max_lags, 1)
-    data_test = df_work.iloc[:, :2]
-    test_output = []
-    def _capture():
-        import sys, io
-        buf = io.StringIO()
-        sys_stdout = sys.stdout
-        try:
-            sys.stdout = buf
-            grangercausalitytests(data_test, maxlag=lag_order, verbose=True)
-        finally:
-            sys.stdout = sys_stdout
-        return buf.getvalue()
-    summary_text = f"Recommended Lag: {lag_order}\n" + _capture()
-
-    # Plot
-    fig, ax1 = plt.subplots()
-    ax2 = ax1.twinx()
-    col1, col2 = data_test.columns[:2]
-    ax1.plot(data_test[col1], label=col1)
-    ax2.plot(data_test[col2], label=col2)
-    ax1.set_xlabel('Time')
-    ax1.set_ylabel(col1)
-    ax2.set_ylabel(col2)
-    ax1.legend(loc='upper left')
-    ax2.legend(loc='upper right')
-    buf = BytesIO()
-    fig.savefig(buf, format='png', bbox_inches='tight')
-    plt.close(fig)
-    buf.seek(0)
-    plot_img = Image.open(buf)
-
-    return summary_text, plot_img, transformed_csv
+# causalscience/models/granger.py
+
+import numpy as np
+import pandas as pd
+import statsmodels.api as sm
+from statsmodels.tsa.stattools import grangercausalitytests, adfuller
+import matplotlib.pyplot as plt
+from io import BytesIO
+from PIL import Image
+
+
+def adf_stationarity_test(series, alpha=0.05):
+    """
+    Perform Augmented Dickey-Fuller test to check stationarity.
+
+    Args:
+        series (pd.Series): Time series data.
+        alpha (float): Significance level.
+
+    Returns:
+        p_value (float): p-value of the test.
+        is_stationary (bool): True if series is stationary.
+    """
+    result = adfuller(series.dropna(), autolag='AIC')
+    p_value = result[1]
+    return p_value, p_value < alpha
+
+
+def difference_series(series, order=1):
+    """
+    Difference the series to make it stationary.
+    """
+    return series.diff(periods=order).dropna()
+
+
+def make_data_stationary(df, columns_to_transform, max_diff=2, alpha=0.05):
+    """
+    Iteratively difference specified columns to achieve stationarity.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame.
+        columns_to_transform (list[str] or None): Columns to transform. If None, no transformation is done.
+        max_diff (int): Maximum differencing order.
+        alpha (float): Significance level.
+
+    Returns:
+        df_out (pd.DataFrame): Transformed DataFrame.
+        transformations (dict): Info on differencing orders for transformed columns.
+    """
+    df_out = df.copy()
+    transformations = {}
+
+    if not columns_to_transform:  # If no columns are specified for transformation
+        return df_out, transformations
+
+    for col in columns_to_transform:
+        if col not in df_out.columns:
+            transformations[col] = "Error: Column not found in DataFrame."
+            continue  # Skip to the next column
+
+        # Ensure the column data is numeric for the ADF test
+        if not pd.api.types.is_numeric_dtype(df_out[col]):
+            transformations[col] = "Skipped: Non-numeric data."
+            continue
+
+        pval, stationary = adf_stationarity_test(df_out[col], alpha)
+        if stationary:
+            transformations[col] = f"Already stationary (p={pval:.4f})"
+            continue
+
+        original_series_for_diff = df[col].copy()  # Always difference from the original non-differenced series
+
+        for d in range(1, max_diff + 1):
+            # Apply differencing iteratively on the original series for this column
+            diff_series_data = difference_series(original_series_for_diff, order=d)
+
+            if diff_series_data.empty:
+                transformations[col] = f"Differenced order {d} resulted in empty series. Stationarity not achieved."
+                # df_out[col] = diff_series_data  # Store the empty series if desired, or original
+                break  # Stop differencing this column
+
+            pval_d, stat_d = adf_stationarity_test(diff_series_data, alpha)
+            if stat_d:
+                df_out[col] = diff_series_data  # Update the column in df_out with the stationary series
+                transformations[col] = f"Differenced order {d} (p={pval_d:.4f})"
+                break
+        else:  # This else belongs to the for loop (if break was not hit)
+            transformations[col] = f"Stationarity not achieved within {max_diff} differencing orders (last p={pval_d:.4f}). Using last differenced series."
+            # df_out[col] = diff_series_data  # Ensure the last differenced series is used even if not stationary
+
+    # df_out = df_out.dropna()  # Dropna AFTER all transformations are applied if needed,
+    # but this might heavily reduce data if different columns have different diff orders.
+    # It's often better to let VAR handle NaNs or for the user to decide.
+    # For Granger, two series are selected *after* this.
+    return df_out, transformations
+
+
+def recommend_var_lag(df, maxlags=7, criterion='aic'):
+    """
+    Recommend lag order for VAR model by information criterion.
+    Assumes df contains only the series for VAR (typically 2 for Granger).
+    """
+    # Ensure no NaNs are passed to the VAR model, common after differencing
+    df_dropna = df.dropna()
+    if df_dropna.shape[0] < maxlags + 1:  # Check if enough data points remain after dropping NaNs
+        # Not enough data to reliably select a lag, or even fit the model.
+        # Default to a small lag, or raise an error.
+        # print(f"Warning: Not enough data points ({df_dropna.shape[0]}) after dropping NaNs for maxlags={maxlags}. Defaulting lag to 1.")
+        return 1
+    if df_dropna.empty:
+        # print("Warning: DataFrame is empty after dropping NaNs. Cannot recommend VAR lag. Defaulting to 1.")
+        return 1
+
+
+    model = sm.tsa.VAR(df_dropna)
+    try:
+        results = model.select_order(maxlags=maxlags)
+        selected_lag = results.selected_orders.get(criterion)
+        return selected_lag if selected_lag is not None else 1  # Default to 1 if criterion not found or is None
+    except Exception as e:
+        # print(f"Error during VAR lag selection: {e}. Defaulting lag to 1.")
+        return 1
+
+
+def run_granger_analysis(df, target_col1, target_col2, max_lags=7, criterion='aic',
+                         apply_transformation=False, columns_to_transform=None,
+                         max_diff=2, alpha=0.05):
+    """
+    Run Granger causality analysis between two specified columns.
+
+    Args:
+        df (pd.DataFrame): Time series DataFrame.
+        target_col1 (str): Name of the first column for Granger analysis.
+        target_col2 (str): Name of the second column for Granger analysis.
+        max_lags (int): Max lags to test.
+        criterion (str): Criterion for lag selection.
+        apply_transformation (bool): Whether to difference to stationarity.
+        columns_to_transform (list[str] or None): Specific columns to attempt to make stationary.
+            If None, and apply_transformation is True, it could
+            default to target_col1 and target_col2, or be an error.
+            Best to be explicit from the UI.
+        max_diff (int): Max differencing.
+        alpha (float): Significance level for ADF test.
+
+    Returns:
+        summary_text (str): Text output of tests.
+        plot_img (PIL.Image.Image or None): Time series plot of the two target series.
+        transformed_csv_path (str or None): Path to CSV of the (potentially transformed) DataFrame.
+        transformation_info (dict): Log of transformations applied.
+    """
+    if not target_col1 or not target_col2:
+        raise ValueError("target_col1 and target_col2 must be specified.")
+    if target_col1 not in df.columns or target_col2 not in df.columns:
+        raise ValueError(f"One or both target columns ('{target_col1}', '{target_col2}') not found in DataFrame columns: {list(df.columns)}")
+    if target_col1 == target_col2:
+        raise ValueError("target_col1 and target_col2 must be different.")
+
+    df_work = df.copy()
+    transformation_info = {}
+    transformed_csv_path = None
+
+    if apply_transformation:
+        # If columns_to_transform is not provided, default to transforming the target columns.
+        # This is a design choice; an alternative would be to raise an error or transform all numeric columns.
+        cols_for_stat = columns_to_transform
+        if not cols_for_stat:  # If an empty list or None was passed, and apply_transformation is True
+            cols_for_stat = [target_col1, target_col2]
+            transformation_info["Note"] = f"No specific columns for transformation provided; applying to target series: {target_col1}, {target_col2}."
+
+        # Ensure only existing columns are in cols_for_stat, especially if user-provided
+        valid_cols_for_stat = [col for col in cols_for_stat if col in df_work.columns]
+        if len(valid_cols_for_stat) < len(cols_for_stat):
+            missing = set(cols_for_stat) - set(valid_cols_for_stat)
+            transformation_info["Warning_Transformation"] = f"Columns not found for stationarity transformation and were skipped: {list(missing)}"
+
+        if valid_cols_for_stat:
+            df_work, trans_log = make_data_stationary(df_work, valid_cols_for_stat, max_diff, alpha)
+            transformation_info.update(trans_log)  # Merge the detailed log
+            transformed_csv_path = 'transformed_data.csv'
+            try:
+                df_work.to_csv(transformed_csv_path, index=False)
+            except Exception as e:
+                transformation_info["CSV_Save_Error"] = f"Could not save transformed data: {str(e)}"
+                transformed_csv_path = None  # Indicate saving failed
+        else:
+            transformation_info["Note_Transformation"] = "No valid columns found or specified for stationarity transformation."
+
+
+    # Select the two target series for Granger causality AFTER potential transformations.
+    # Ensure they still exist and are numeric.
+    if target_col1 not in df_work.columns or target_col2 not in df_work.columns:
+        # This could happen if differencing made a column all NaN and it got dropped by some operation
+        raise ValueError(f"Target columns '{target_col1}' or '{target_col2}' are no longer in the DataFrame after transformations. Check transformation log.")
+
+    series1_data = df_work[target_col1]
+    series2_data = df_work[target_col2]
+
+    if not pd.api.types.is_numeric_dtype(series1_data) or \
+       not pd.api.types.is_numeric_dtype(series2_data):
+        raise ValueError(f"Target columns '{target_col1}' and/or '{target_col2}' must be numeric for Granger causality. Check data types after transformations.")
+
+    data_for_test = pd.DataFrame({target_col1: series1_data, target_col2: series2_data})
+
+    # For VAR lag selection, we need to drop NaNs from the pair of series
+    data_for_var_lag_selection = data_for_test.dropna()
+
+    if data_for_var_lag_selection.shape[0] < max_lags + 1:  # Check if enough data points
+        # This check is now more critical as differencing can reduce data significantly
+        transformation_info["LagSelectionWarning"] = (
+            f"Not enough non-NaN data points ({data_for_var_lag_selection.shape[0]}) "
+            f"in '{target_col1}' & '{target_col2}' pair for maxlags={max_lags} after transformations/NaN removal. "
+            f"Defaulting lag to 1 or minimum possible."
+        )
+        # Attempt to determine a lag with available data, or default to 1
+        effective_max_lags = min(max_lags, max(1, data_for_var_lag_selection.shape[0] // 3 - 1))  # Heuristic
+        if effective_max_lags < 1: effective_max_lags = 1
+
+        lag_order = recommend_var_lag(data_for_var_lag_selection, maxlags=effective_max_lags, criterion=criterion)
+        lag_order = lag_order or 1  # Ensure lag_order is at least 1
+    elif data_for_var_lag_selection.empty:
+        transformation_info["LagSelectionError"] = (
+            f"DataFrame for VAR lag selection between '{target_col1}' & '{target_col2}' is empty after transformations/NaN removal. "
+            "Cannot perform Granger causality. Defaulting lag to 1 for report."
+        )
+        lag_order = 1
+    else:
+        lag_order = recommend_var_lag(data_for_var_lag_selection, maxlags=max_lags, criterion=criterion)
+        lag_order = lag_order or min(max_lags, 1)  # Ensure lag_order is at least 1 if recommend_var_lag returns None
+
+    granger_input_df = data_for_test[[target_col1, target_col2]]  # Ensure correct column order for interpretation
+
+    summary_text_parts = [f"Granger Causality Analysis for '{target_col1}' and '{target_col2}'"]
+    summary_text_parts.append(f"Recommended Lag Order (based on VAR on processed series, criterion: {criterion}): {lag_order}\n")
+
+    if granger_input_df.dropna().shape[0] < lag_order + 1:
+        summary_text_parts.append(
+            f"Critical Warning: After NaN removal for the pair ('{target_col1}', '{target_col2}'), "
+            f"only {granger_input_df.dropna().shape[0]} observations remain. "
+            f"This may be insufficient for Granger causality tests with lag {lag_order}. Results might be unreliable or fail.\n"
+        )
+        # Return early or let grangercausalitytests try and fail.
+        # For now, let it try; it might still work for lag 1 if the data is very short.
+
+    # Capture output from grangercausalitytests
+    capture_buffer = None
+    try:
+        import sys, io
+        original_stdout = sys.stdout
+        sys.stdout = capture_buffer = io.StringIO()
+        # Note: grangercausalitytests expects the first column to be Y (effect)
+        # and the second to be X (cause) for the test "X -> Y".
+        # The function runs tests for both directions.
+        # The order here (target_col1, then target_col2) means the first set of tests is "target_col2 -> target_col1".
+        grangercausalitytests(granger_input_df, maxlag=lag_order, verbose=True)
+    except Exception as e:
+        summary_text_parts.append(f"Error during grangercausalitytests execution: {str(e)}\n")
+    finally:
+        if capture_buffer:
+            summary_text_parts.append(capture_buffer.getvalue())
+        sys.stdout = original_stdout  # Restore stdout
+
+    summary_text = "\n".join(summary_text_parts)
+
+    # Plot the two selected series (potentially transformed)
+    plot_img = None
+    try:
+        fig, ax1 = plt.subplots()
+        ax2 = ax1.twinx()
+
+        # Use .dropna() for plotting to avoid issues if leading/trailing NaNs exist from differencing,
+        # but only for the series being plotted, without changing data_for_test
+        plot_series1 = data_for_test[target_col1].dropna()
+        plot_series2 = data_for_test[target_col2].dropna()
+
+        ax1.plot(plot_series1.index, plot_series1, label=target_col1, color='blue')
+        ax2.plot(plot_series2.index, plot_series2, label=target_col2, color='red')
+
+        ax1.set_xlabel('Time / Index')
+        ax1.set_ylabel(target_col1, color='blue')
+        ax2.set_ylabel(target_col2, color='red')
+
+        # Combine legends
+        lines1, labels1 = ax1.get_legend_handles_labels()
+        lines2, labels2 = ax2.get_legend_handles_labels()
+        ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
+
+        plt.title(f"Time Series: {target_col1} vs {target_col2}")
+        fig.tight_layout()  # Adjust layout
+
+        buf = BytesIO()
+        fig.savefig(buf, format='png', bbox_inches='tight')
+        plt.close(fig)
+        buf.seek(0)
+        plot_img = Image.open(buf)
+    except Exception as e:
+        # Add to transformation_info as it's a non-critical error for the text output
+        transformation_info["Plotting_Error"] = f"Could not generate plot: {str(e)}"
+
+
+    return summary_text, plot_img, transformed_csv_path, transformation_info
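
For context on the signature change in this commit: run_granger_analysis now takes explicit target columns and returns a fourth value, transformation_info. The following is a minimal usage sketch, not part of the commit; the synthetic data, the column names "gdp" and "unemployment", and the import path are illustrative assumptions.

import numpy as np
import pandas as pd
from models.granger import run_granger_analysis  # import path assumed from the repo layout

# Synthetic example data: a random-walk driver and a lagged, noisy response
rng = np.random.default_rng(0)
x = np.cumsum(rng.normal(size=200))
y = 0.5 * np.roll(x, 2) + rng.normal(size=200)
df = pd.DataFrame({"gdp": x, "unemployment": y})

summary_text, plot_img, csv_path, info = run_granger_analysis(
    df,
    target_col1="gdp",
    target_col2="unemployment",
    max_lags=5,
    apply_transformation=True,   # difference the targets until the ADF test passes (up to max_diff)
    columns_to_transform=None,   # defaults to the two target columns; a "Note" entry is added to info
)

print(summary_text)              # captured grangercausalitytests output plus the recommended lag
print(info)                      # per-column differencing log, warnings, and any CSV/plot errors
if plot_img is not None:
    plot_img.save("granger_plot.png")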