Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
"""Gradio app that forecasts NASA POWER monthly radiation data with SARIMA.

Automatically generated by Colab from the notebook app.ipynb.
The original notebook is located at
https://colab.research.google.com/drive/1I5SzMTkFWVzSyzaZNkQcYqUi7F9Nzxpx
"""
import gradio as gr | |
import xarray as xr | |
import pandas as pd | |
from statsmodels.tsa.statespace.sarimax import SARIMAX | |
from sklearn.metrics import mean_squared_error | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from datetime import datetime | |
import fsspec | |
# Remote location of the NASA POWER monthly radiation Zarr store.
FILEPATH = 'https://power-analysis-ready-datastore.s3.amazonaws.com/power_901_monthly_radiation_utc.zarr'


def load_dataset(filepath):
    """Open the Zarr store found at *filepath* and return it as an xarray Dataset.

    The location is first wrapped in an fsspec mapper, which lets remote
    URLs such as ``FILEPATH`` be used as a Zarr store.
    ``consolidated=True`` tells xarray to read the store's consolidated
    metadata.
    """
    mapper = fsspec.get_mapper(filepath)
    opened = xr.open_zarr(store=mapper, consolidated=True)
    return opened


# Opened once, at import time; every forecast request reads from this object.
dataset = load_dataset(FILEPATH)
# Function for SARIMA forecasting
def predict_forecast(lat, lon, start_date, end_date, forecast_steps):
    """Fit a SARIMA model to one grid cell's time series and forecast ahead.

    The series is ``dataset.ALLSKY_SFC_LW_DWN`` at the grid point nearest
    to (*lat*, *lon*), restricted to ``slice(start_date, end_date)`` on the
    ``time`` coordinate. The first 80% of the samples train a
    SARIMA(1,1,1)x(1,1,1,12) model. The forecast starts immediately after
    the training window and is scored against the held-out remaining samples.

    Args:
        lat: Query latitude; snapped to the nearest grid latitude.
        lon: Query longitude; snapped to the nearest grid longitude.
        start_date: Start label of the time slice (e.g. ``"2019-01-01"``).
        end_date: End label of the time slice (e.g. ``"2024-12-31"``).
        forecast_steps: Number of steps to forecast. It is coerced to
            ``int`` (Gradio number inputs may deliver floats) and must be
            at least 1.

    Returns:
        On success, a dict with these keys:
            ``"RMSE"``: float RMSE over the forecast steps that have a
                held-out observation, or ``nan`` when none do.
            ``"Forecast"``: list of forecast values.
            ``"Plot"``: path of the saved PNG.
        On any failure, a string of the form ``"Error: <message>"``.
        Exceptions are not propagated to the caller.
    """
    try:
        # SARIMAX.forecast needs an integer step count.
        steps = int(forecast_steps)
        if steps < 1:
            raise ValueError("forecast_steps must be a positive integer")

        # Find nearest latitude and longitude
        nearest_lat = dataset.lat.sel(lat=lat, method='nearest').item()
        nearest_lon = dataset.lon.sel(lon=lon, method='nearest').item()

        # Slice data for nearest lat/lon and the time range
        # NOTE(review): the plot labels this "Solar Radiation", but the variable
        # name suggests all-sky downward longwave flux -- confirm the intended variable.
        ds_time_series = dataset.ALLSKY_SFC_LW_DWN.sel(
            lat=nearest_lat, lon=nearest_lon, time=slice(start_date, end_date)
        ).load()
        data_series = ds_time_series.values
        dates = pd.to_datetime(ds_time_series.time.values)

        # Split into training (first 80%) and held-out test data
        train_size = int(len(data_series) * 0.8)
        if train_size == 0:
            raise ValueError("not enough data points in the selected date range")
        train, test = data_series[:train_size], data_series[train_size:]

        # Fit SARIMA model and forecast from the end of the training window
        model = SARIMAX(train, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
        result = model.fit(disp=False)
        forecast = np.asarray(result.forecast(steps=steps))

        # Score only the steps that have a held-out observation. Steps past
        # end_date have nothing to compare against, and unequal lengths would
        # make mean_squared_error raise.
        overlap = min(len(test), len(forecast))
        if overlap:
            rmse = float(np.sqrt(mean_squared_error(test[:overlap], forecast[:overlap])))
        else:
            rmse = float("nan")

        # Timestamps for the forecast. The first steps reuse the held-out
        # dates; any steps beyond the data continue month by month after the
        # last observation. This assumes monthly sampling, the same
        # assumption behind the seasonal period of 12.
        n_future = max(0, len(forecast) - len(test))
        future_dates = pd.DatetimeIndex(
            [dates[-1] + pd.DateOffset(months=k) for k in range(1, n_future + 1)]
        )
        forecast_dates = dates[train_size:train_size + len(forecast)].append(future_dates)

        # Create the plot. Always close the figure so repeated requests do
        # not accumulate open pyplot figures.
        plot_path = "forecast_plot.png"
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(dates[train_size:], test, label='Actual')
            ax.plot(forecast_dates, forecast, label='Forecast', linestyle='--')
            ax.legend()
            ax.set_title(f"Forecast vs Actual (RMSE: {rmse:.2f})")
            ax.set_xlabel("Date")
            ax.set_ylabel("Solar Radiation")
            fig.savefig(plot_path)
        finally:
            plt.close(fig)

        return {
            "RMSE": rmse,
            "Forecast": forecast.tolist(),
            "Plot": plot_path,
        }
    except Exception as e:
        # UI boundary: report the failure to the user as text instead of raising.
        return f"Error: {str(e)}"
# Gradio app
def gradio_interface(lat, lon, start_date, end_date, forecast_steps):
    """Map the result of predict_forecast onto the three Gradio outputs.

    Returns ``(rmse, forecast_values, plot_path)`` when predict_forecast
    succeeds (returns a dict). Otherwise its return value (an error
    string) goes to the first output and the other two outputs are ``None``.
    """
    outcome = predict_forecast(lat, lon, start_date, end_date, forecast_steps)
    if not isinstance(outcome, dict):
        # Failure path: show the error text in the RMSE box and leave the rest empty.
        return outcome, None, None
    return outcome["RMSE"], outcome["Forecast"], outcome["Plot"]
# Gradio UI: the five inputs feed gradio_interface, whose returned
# (rmse, forecast_values, plot_path) tuple fills the three outputs in order.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Number(label="Latitude"),
        gr.Number(label="Longitude"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="2019-01-01"),
        gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="2024-12-31"),
        # precision=0 makes Gradio deliver an integer; the value is a step count.
        gr.Number(label="Forecast Steps (e.g., 12 for 1 year)", precision=0),
    ],
    outputs=[
        gr.Textbox(label="RMSE"),
        gr.Textbox(label="Forecast Values"),
        gr.Image(label="Forecast Plot"),
    ],
    title="Solar Radiation Forecast with SARIMA",
    description="Enter location and time range to forecast solar radiation using SARIMA.",
)
# Launch the app
interface.launch()