magcheong's picture
Update app.py
df7f764 verified
raw
history blame contribute delete
No virus
2.92 kB
import functools
import os
import pickle

import numpy as np
import pandas as pd
import datasets
from datasets import load_dataset, Dataset, DatasetDict
from torch.utils.data import DataLoader
from six import BytesIO
import gradio as gr
from huggingface_hub import snapshot_download
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
# Hugging Face repo that holds both the saved model and the future frame.
REPO_ID = "magcheong/ITI110_Energy_Prediction"

# Local path of the downloaded repo snapshot.
download_dir = snapshot_download(REPO_ID)
# Future-covariates frame handed to the model as `futr_df` at prediction time.
test_file = pd.read_csv(os.path.join(download_dir, "futr_df.csv"))
test = pd.DataFrame(test_file)
# Parse the timestamp column explicitly with the exact format stored in the CSV.
test['ds'] = pd.to_datetime(test["ds"], format='%Y-%m-%d %H:%M:%S')
def load_model():
    """Download the model repo and restore the saved NeuralForecast model.

    Returns:
        The ``NeuralForecast`` object loaded from the repo's
        ``saved_model`` directory.
    """
    # Resolve the snapshot here instead of relying on the module-level
    # download_dir, so this helper stands on its own.
    model_dir = os.path.join(snapshot_download(REPO_ID), "saved_model")
    return NeuralForecast.load(model_dir)
# Load the trained model once at import time; predict2() uses this
# module-level instance for every request instead of reloading it.
prediction_model = load_model()
def UserInputDays(number_of_days_ahead):
    """Convert a whole-day offset into a row offset in the forecast.

    Each day ahead advances 24 rows (the forecast is presumably hourly).
    The input is coerced with ``int()``, so numeric strings and floats are
    accepted (floats are truncated).
    """
    hours_per_day = 24
    return hours_per_day * int(number_of_days_ahead)
def UserInputTime(hour_of_the_day):
    """Normalise the hour-of-day input to an ``int``.

    Missing input (``None``, ``''`` or whitespace-only text, e.g. from a
    cleared text box) defaults to hour 0. Anything else is converted with
    ``int()``, so numeric strings and floats are accepted.

    Raises:
        ValueError: if a non-empty string is not an integer literal.
    """
    # Treat every "nothing entered" form as hour 0, not only the empty string.
    if hour_of_the_day is None or (
        isinstance(hour_of_the_day, str) and not hour_of_the_day.strip()
    ):
        return 0
    return int(hour_of_the_day)
@functools.lru_cache(maxsize=1)
def _forecast():
    """Run the model over ``test`` once and memoise the resulting frame."""
    # Neither prediction_model nor test is reassigned in this module, so every
    # request would otherwise recompute the same forecast. This assumes
    # predict() is deterministic for a fixed model and futr_df (TODO confirm).
    return prediction_model.predict(futr_df=test).reset_index()


def predict2(row):
    """Return the NBEATS forecast at position ``row``, rounded to 3 decimals.

    Row 0 is the first forecast step. Per the original author's note, the
    first day of ``test`` starts at 2014-02-21 01:00:00.

    Raises:
        IndexError: if ``row`` is negative or past the end of the forecast.
    """
    # iloc would silently wrap a negative position around to the end.
    if row < 0:
        raise IndexError(f"row must be non-negative, got {row}")
    prediction = _forecast()
    return round(prediction.iloc[row]['NBEATS'], 3)
def predict(number_of_days_ahead, hour_of_the_day):
    """Gradio handler: return the forecast for a day offset and hour of day.

    The row looked up is ``24 * days + hour`` from the start of the forecast.

    NOTE(review): the forecast reportedly begins at 2014-02-21 01:00:00, so
    hour ``h`` here may actually correspond to clock hour ``h + 1``. Confirm.
    """
    day_rows = UserInputDays(number_of_days_ahead)
    hour_rows = UserInputTime(hour_of_the_day)
    return predict2(day_rows + hour_rows)
# def day_average(number_of_days_ahead):
# n = UserInputDays(number_of_days_ahead)
# if n == 0:
# return round(prediction.iloc[:23]['NBEATS'].mean(), 3)
# else:
# start_n = 23 + 24*(n-1)
# return round(prediction.iloc[start_n:start_n+24]['NBEATS'].mean(), 3)
title = "ITI110 Energy Prediction"
description = "This is an app to predict energy consumption in London over the next 7 days."
# Optional page CSS. It is not currently passed to gr.Interface (its css=
# argument is commented out). `repeat 0 0` was a stray, invalid declaration;
# it belongs inside the `background` shorthand.
# NOTE(review): a Google Drive ".../view" link likely serves an HTML page,
# not the raw image. Confirm the URL before enabling this CSS.
css_code = (
    'div {margin-left: auto; margin-right: auto; width: 100%; '
    'background: url("https://drive.google.com/file/d/1MZmYop1st_lAuDbvKIuTbjxj-xjPtwX-/view?usp=sharing") repeat 0 0;}'
)
# Web UI: a day-offset slider (0-6) and an hour slider (0-23) feed predict(),
# whose rounded result is shown in a text box.
demo = gr.Interface(
    predict,
    title=title,
    description=description,
    # css=css_code,
    inputs=[
        gr.Slider(0, 6, 1, step=1, label='Select number of days ahead.'),
        gr.Slider(0, 23, 1, step=1, label='Select hour of the day.'),
    ],
    outputs=gr.Textbox(label='Predicted energy consumption at selected hour:'),
    theme='finlaymacklon/smooth_slate',
)
# NOTE(review): share=True asks Gradio for a public share link; confirm this
# is wanted on the hosting platform.
demo.launch(share=True)
# gr.Interface(predict,
# title = title,
# description = description,
# css=css_code,
# inputs=["textbox","textbox"],
# outputs="textbox"
# ).launch(share=True)