# PowerSage / app.py
# yashwantpatilyup's picture
# Upload app.py
# 5d612b2 verified
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1I5SzMTkFWVzSyzaZNkQcYqUi7F9Nzxpx
"""
import gradio as gr
import xarray as xr
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import fsspec
# Filepath for NASA POWER data
FILEPATH = 'https://power-analysis-ready-datastore.s3.amazonaws.com/power_901_monthly_radiation_utc.zarr'
# Load dataset
def load_dataset(filepath):
    """Open the Zarr store at ``filepath`` and return it as an xarray dataset.

    The path is wrapped in an fsspec mapper and the store is opened with
    consolidated metadata.
    """
    return xr.open_zarr(store=fsspec.get_mapper(filepath), consolidated=True)
# Opened once at import time; predict_forecast reads this module-level object
# on every call instead of reopening the store.
dataset = load_dataset(FILEPATH)
# Function for SARIMA forecasting
def predict_forecast(lat, lon, start_date, end_date, forecast_steps):
    """Fit a SARIMA model to one grid cell's series and forecast ahead.

    The ``ALLSKY_SFC_LW_DWN`` series of the grid cell nearest to
    (``lat``, ``lon``) is restricted to ``start_date``..``end_date`` and split
    80/20 into a training part and a hold-out part. A
    SARIMA(1, 1, 1)x(1, 1, 1, 12) model fitted on the training part produces
    ``forecast_steps`` predictions that start right after the training data.

    Args:
        lat: Latitude; the nearest latitude present in ``dataset`` is used.
        lon: Longitude; the nearest longitude present in ``dataset`` is used.
        start_date: Start of the time slice (e.g. ``"2019-01-01"``).
        end_date: End of the time slice (e.g. ``"2024-12-31"``).
        forecast_steps: Number of periods to forecast. Converted with
            ``int()`` so that float values coming from a UI widget work;
            must be at least 1.

    Returns:
        On success, a dict with the keys ``"RMSE"`` (error over the part of
        the forecast that overlaps the hold-out data), ``"Forecast"`` (list
        of predicted values) and ``"Plot"`` (path of the saved PNG).
        On any failure, the string ``"Error: <message>"``.
    """
    try:
        # Validate first so bad input fails fast, before any data is loaded.
        try:
            steps = int(forecast_steps)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(
                "forecast_steps must be a positive whole number"
            ) from None
        if steps < 1:
            raise ValueError("forecast_steps must be a positive whole number")

        # Resolve the grid cell in a separate step: the nearest-neighbour
        # lookup is done per coordinate, then the exact values are used
        # together with the time slice.
        nearest_lat = dataset.lat.sel(lat=lat, method='nearest').item()
        nearest_lon = dataset.lon.sel(lon=lon, method='nearest').item()

        ds_time_series = dataset.ALLSKY_SFC_LW_DWN.sel(
            lat=nearest_lat, lon=nearest_lon, time=slice(start_date, end_date)
        ).load()
        data_series = ds_time_series.values
        dates = pd.to_datetime(ds_time_series.time.values)

        # With fewer than two points either the training or the hold-out
        # part would be empty.
        if len(data_series) < 2:
            raise ValueError(
                "not enough data in the selected time range to build "
                "training and test sets"
            )

        # First 80% for fitting, remaining 20% for evaluation.
        train_size = int(len(data_series) * 0.8)
        train, test = data_series[:train_size], data_series[train_size:]

        model = SARIMAX(train, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
        result = model.fit(disp=False)
        forecast = np.asarray(result.forecast(steps=steps))

        # The forecast may be longer than the hold-out data; score only the
        # overlapping part so both arrays always have the same length.
        overlap = min(len(test), len(forecast))
        rmse = np.sqrt(mean_squared_error(test[:overlap], forecast[:overlap]))

        # X values for the forecast: observed timestamps where they exist,
        # then extrapolated ones. Extrapolation assumes a monthly series
        # (consistent with the seasonal period of 12) -- TODO confirm.
        forecast_dates = dates[train_size:train_size + len(forecast)]
        missing = len(forecast) - len(forecast_dates)
        if missing > 0:
            future = pd.DatetimeIndex(
                [dates[-1] + pd.DateOffset(months=k)
                 for k in range(1, missing + 1)]
            )
            forecast_dates = forecast_dates.append(future)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(dates[train_size:], test, label='Actual')
            ax.plot(forecast_dates, forecast, label='Forecast', linestyle='--')
            ax.legend()
            ax.set_title(f"Forecast vs Actual (RMSE: {rmse:.2f})")
            ax.set_xlabel("Date")
            ax.set_ylabel("Solar Radiation")
            fig.savefig("forecast_plot.png")
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this each call would leak one figure.
            plt.close(fig)

        return {
            "RMSE": rmse,
            "Forecast": forecast.tolist(),
            "Plot": "forecast_plot.png"
        }
    except Exception as e:
        # Callers tell success from failure by the return type (dict vs str),
        # so every error is reported as a message instead of propagating.
        return f"Error: {str(e)}"
# Gradio app
def gradio_interface(lat, lon, start_date, end_date, forecast_steps):
    """Adapt ``predict_forecast`` to the three outputs of the Gradio UI.

    Returns a 3-tuple. On success it is (RMSE, forecast values, plot path),
    taken from the dict returned by ``predict_forecast``. Otherwise the
    non-dict result (the error message) goes in the first slot and the other
    two slots are ``None``.
    """
    outcome = predict_forecast(lat, lon, start_date, end_date, forecast_steps)

    # Anything that is not a dict is the error message.
    if not isinstance(outcome, dict):
        return outcome, None, None

    return outcome["RMSE"], outcome["Forecast"], outcome["Plot"]
# Input and output components are listed in the same order as the parameters
# and return values of gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Number(label="Latitude"),
        gr.Number(label="Longitude"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="2019-01-01"),
        gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="2024-12-31"),
        # precision=0 makes Gradio round the entry and pass it on as an int;
        # a step count must be a whole number.
        gr.Number(label="Forecast Steps (e.g., 12 for 1 year)", precision=0),
    ],
    outputs=[
        gr.Textbox(label="RMSE"),
        gr.Textbox(label="Forecast Values"),
        gr.Image(label="Forecast Plot"),
    ],
    title="Solar Radiation Forecast with SARIMA",
    description="Enter location and time range to forecast solar radiation using SARIMA.",
)
# Launch the app. This call sits at module top level (no __main__ guard), so
# it runs whenever this file is executed or imported.
interface.launch()