Spaces:
Runtime error
Runtime error
NaokiOkamoto
commited on
Commit
•
ec74bc1
1
Parent(s):
9ddce62
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
+
import datetime
|
5 |
+
import calendar
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import japanize_matplotlib
|
8 |
+
import matplotlib.dates as mdates
|
9 |
+
from dateutil.relativedelta import relativedelta
|
10 |
+
import datetime
|
11 |
+
import datarobot as dr
|
12 |
+
from function import get_fish_qty, get_estat, dr_prediction_deployment, prediction_func, train_modeling
|
13 |
+
|
14 |
+
import yaml
|
15 |
+
with open('config.yaml') as file:
|
16 |
+
config = yaml.safe_load(file.read())
|
17 |
+
|
18 |
+
def retrain():
|
19 |
+
model_management_df = train_modeling.modeling()
|
20 |
+
|
21 |
+
model = dr.Model.get(project = dr.Project.get(model_management_df.iloc[0, :]['model_url'].split('/')[4]),
|
22 |
+
model_id = model_management_df.iloc[0, :]['model_url'].split('/')[-1])
|
23 |
+
feature_impact = pd.DataFrame(model.get_or_request_feature_impact())
|
24 |
+
feature_impact = feature_impact.sort_values('impactNormalized', ascending=False).reset_index(drop=True)
|
25 |
+
feature_impact = feature_impact.iloc[:20, :]
|
26 |
+
for i in range(len(feature_impact)):
|
27 |
+
feature_impact['featureName'][i] = str(i+1).zfill(2) + '_' + feature_impact['featureName'][i]
|
28 |
+
|
29 |
+
return model_management_df.iloc[0, :]['model_type'], model.metrics['RMSE']['holdout'], feature_impact
|
30 |
+
|
31 |
+
|
32 |
+
def get_prediction_result():
|
33 |
+
today = datetime.datetime.now()
|
34 |
+
prediction_month = (today+relativedelta(months=1)).strftime('%Y%m')
|
35 |
+
month_days = month_days = [pd.to_datetime(prediction_month + str(i+1).zfill(2)) for i in range(calendar.monthrange((today+relativedelta(months=1)).year, (today+relativedelta(months=1)).month)[1])]
|
36 |
+
dfc = pd.DataFrame({'target_date':month_days})
|
37 |
+
df = prediction_func.prediction_to_dr(config['oil_price_url'], config['fuel_procurement_cost_url'])
|
38 |
+
df = df.loc[df['target_date'].astype(str).str[:6]==prediction_month]
|
39 |
+
df['target_date'] = pd.to_datetime(df['target_date'].astype(str))
|
40 |
+
df['forecast_point'] = pd.to_datetime(df['forecast_point'].astype(str))
|
41 |
+
df = pd.merge(dfc,
|
42 |
+
df,
|
43 |
+
on='target_date',
|
44 |
+
how='left')
|
45 |
+
df.loc[df['forecast_point'].isnull(), 'forecast_point'] = df['target_date'].apply(lambda x:x-relativedelta(months=1))
|
46 |
+
df = df.loc[~((df['target_date']<(today+relativedelta(months=1)))&(df['電気代'].isnull()))]
|
47 |
+
df = df.rename(columns={'電気代':'電気代_予測'})
|
48 |
+
return df[['forecast_point', 'target_date', '電気代_予測']]
|
49 |
+
|
50 |
+
def plot_prediction_result():
|
51 |
+
update = gr.LinePlot.update(
|
52 |
+
value=get_prediction_result(),
|
53 |
+
x="target_date",
|
54 |
+
y="電気代_予測",
|
55 |
+
title="昨日までの魚の卸売り量から予測された、来月の2人世帯の平均電気料金の推移",
|
56 |
+
width=500,
|
57 |
+
height=300,
|
58 |
+
)
|
59 |
+
return update
|
60 |
+
|
61 |
+
def get_model_infomation():
|
62 |
+
token = 'NjQwMDVmNGI0ZDQzZDFhYzI2YThmZDJiOnVZejljTXFNTXNoUnlKMStoUFhXSFdYMEZRck9lY3dobnEvRFZ1aVBHbVE9'
|
63 |
+
endpoint = 'https://app.datarobot.com/api/v2'
|
64 |
+
dr.Client(
|
65 |
+
endpoint=endpoint,
|
66 |
+
token=token
|
67 |
+
)
|
68 |
+
project = dr.Project.get([i for i in dr.Project.list() if '電気代予測' in str(i)][0].id)
|
69 |
+
|
70 |
+
model_df = pd.DataFrame(
|
71 |
+
[[model.id,
|
72 |
+
model.model_type,
|
73 |
+
model.metrics['RMSE']['validation'],
|
74 |
+
model.metrics['RMSE']['backtesting'],
|
75 |
+
model.metrics['RMSE']['holdout'],
|
76 |
+
model] for model in project.get_datetime_models() if model.model_type != 'Baseline Predictions Using Most Recent Value'],
|
77 |
+
columns=['ID', 'モデル名', 'バックテスト1', '全てのバックテスト', 'holdout', 'model'])
|
78 |
+
model_df = model_df.sort_values('holdout').reset_index(drop=True)
|
79 |
+
|
80 |
+
model = model_df['model'][0]
|
81 |
+
model_info = {}
|
82 |
+
model_info['RMSE'] = model.metrics['RMSE']['holdout']
|
83 |
+
model_info['model_type'] = model.model_type
|
84 |
+
model_info['model_type'] = model.model_type
|
85 |
+
|
86 |
+
feature_impact = pd.DataFrame(model.get_or_request_feature_impact())
|
87 |
+
feature_impact = feature_impact.sort_values('impactNormalized', ascending=False).reset_index(drop=True)
|
88 |
+
feature_impact = feature_impact.iloc[:20, :]
|
89 |
+
|
90 |
+
|
91 |
+
return model_info, feature_impact
|
92 |
+
|
93 |
+
# def get_featuredrift():
|
94 |
+
# deployment = dr.Deployment.get(deployment_id='640d791796a6a52d92c368a0')
|
95 |
+
# target_drift = dr.models.TargetDrift.get(deployment.id)
|
96 |
+
# feature_drift_list = dr.models.FeatureDrift.list(deployment.id)
|
97 |
+
# drift_df = pd.DataFrame(
|
98 |
+
# {
|
99 |
+
# 'feature_name':[target_drift.target_name],
|
100 |
+
# 'drift_score':[target_drift.drift_score],
|
101 |
+
# 'feature_impact':[1]
|
102 |
+
# }
|
103 |
+
# )
|
104 |
+
# drift_df = pd.concat([
|
105 |
+
# drift_df,
|
106 |
+
# pd.DataFrame(
|
107 |
+
# [[
|
108 |
+
# feature_drift.name,
|
109 |
+
# feature_drift.drift_score,
|
110 |
+
# feature_drift.feature_impact
|
111 |
+
# ] for feature_drift in feature_drift_list
|
112 |
+
# ],
|
113 |
+
# columns=[ 'feature_name', 'drift_score', 'feature_impact']
|
114 |
+
# )
|
115 |
+
# ])
|
116 |
+
# start_point = (target_drift.period['start']+relativedelta(hours=9)).strftime("%Y / %m / %d %H:%M:%S")
|
117 |
+
# end_point = (target_drift.period['end']+relativedelta(hours=9)).strftime("%Y / %m / %d %H:%M:%S")
|
118 |
+
|
119 |
+
# return drift_df, start_point, end_point
|
120 |
+
|
121 |
+
with gr.Blocks() as electoric_ploting:
|
122 |
+
gr.Markdown(
|
123 |
+
"""
|
124 |
+
# その日の魚の卸売り量から、来月の家計データ月別支出の電気代を予測するAI
|
125 |
+
使用データ
|
126 |
+
* 東京卸売市場日報
|
127 |
+
* 家計調査の月別支出
|
128 |
+
* 原油価格データ
|
129 |
+
* 燃料調達価格データ
|
130 |
+
why
|
131 |
+
電気代のtrendは原油価格などが大きく影響するが、細かい変化は気候に影響し、気候はある程度海水温に関連性があると考えられる。
|
132 |
+
また、魚の卸売量は水揚げ量に関係し、水揚げ量は海水温に関係するという考えからモデルを作成。
|
133 |
+
"""
|
134 |
+
)
|
135 |
+
with gr.Tab("予測結果"):
|
136 |
+
with gr.Row():
|
137 |
+
with gr.Column():
|
138 |
+
plot = gr.LinePlot(show_label=False)
|
139 |
+
# plot = gr.Plot(label="昨日までの魚の卸売り量から予測された、来月の2人世帯の平均電気料金の推移")
|
140 |
+
with gr.Column():
|
141 |
+
df = get_prediction_result()
|
142 |
+
gr.Textbox(df['電気代_予測'].max(),
|
143 |
+
label='現在までの予測値の最大値')
|
144 |
+
gr.Textbox(df['電気代_予測'].min(),
|
145 |
+
label='現在までの予測値の最小値')
|
146 |
+
gr.Textbox(df['電気代_予測'].mean(),
|
147 |
+
label='現在までの予測値の平均値')
|
148 |
+
gr.Textbox(df['電気代_予測'].median(),
|
149 |
+
label='現在までの予測値の中央値')
|
150 |
+
with gr.Row():
|
151 |
+
gr.DataFrame(get_prediction_result)
|
152 |
+
|
153 |
+
|
154 |
+
with gr.Tab("モデル情報"):
|
155 |
+
gr.Markdown(
|
156 |
+
"""
|
157 |
+
注意:
|
158 |
+
再学習後はモデルのデプロイが自動で行われます。
|
159 |
+
huggingfaceの使用上csvを上書きできないため。
|
160 |
+
"""
|
161 |
+
)
|
162 |
+
retrain_btn= gr.Button(value="再学習")
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column():
|
165 |
+
model_info, feature_impact_df = get_model_infomation()
|
166 |
+
gr.Textbox(model_info['model_type'], label='現在のモデル')
|
167 |
+
|
168 |
+
with gr.Column():
|
169 |
+
output_model_type = gr.Textbox(label='再学習後のモデル')
|
170 |
+
|
171 |
+
with gr.Row():
|
172 |
+
with gr.Column():
|
173 |
+
gr.Textbox(model_info['RMSE'],label=f'Holdout RMSE精度')
|
174 |
+
with gr.Column():
|
175 |
+
output_acc = gr.Textbox(label='再学習後のHoldout RMSE精度')
|
176 |
+
|
177 |
+
with gr.Row():
|
178 |
+
with gr.Column():
|
179 |
+
for i in range(len(feature_impact_df)):
|
180 |
+
feature_impact_df['featureName'][i] = str(i+1).zfill(2) + '_' + feature_impact_df['featureName'][i]
|
181 |
+
gr.BarPlot(value = feature_impact_df,
|
182 |
+
title = '特徴量インパクト上位20',
|
183 |
+
x = 'featureName',
|
184 |
+
y = 'impactNormalized',
|
185 |
+
tooltip=['impactNormalized'],
|
186 |
+
x_title = '特徴量名',
|
187 |
+
y_title = '特徴量インパクト_相対値',
|
188 |
+
vertical=False,
|
189 |
+
y_lim=[0, 1.2],
|
190 |
+
width=400,
|
191 |
+
height=300)
|
192 |
+
with gr.Column():
|
193 |
+
output_plot = gr.BarPlot(title = '再学習後特徴量インパクト上位20',
|
194 |
+
x = 'featureName',
|
195 |
+
y = 'impactNormalized',
|
196 |
+
tooltip=['impactNormalized'],
|
197 |
+
x_title = '特徴量名',
|
198 |
+
y_title = '特徴量インパクト_相対値',
|
199 |
+
vertical=False,
|
200 |
+
y_lim=[0, 1.2],
|
201 |
+
width=400,
|
202 |
+
height=300)
|
203 |
+
# with gr.Tab("データドリフト情報"):
|
204 |
+
# result = get_featuredrift()
|
205 |
+
# with gr.Row():
|
206 |
+
# gr.Markdown(
|
207 |
+
# """
|
208 |
+
# こちらの図はデータドリフトと特徴量の有用性を表した図になっています。
|
209 |
+
# 味方は以下の通り
|
210 |
+
# * ドリフトスコア:予測データに含まれるデータが、どれぐらい過去のデータに比べてずれが発生しているかを表しており、上に行けば行くほどズレが大きい
|
211 |
+
# * 特徴量の有用性:ターゲットの有用性を1とした時に、どれぐらいそれぞれの特徴量の有用性が高いかを表したもので、右に行くほど有用性が高い
|
212 |
+
# """
|
213 |
+
# )
|
214 |
+
# with gr.Row():
|
215 |
+
# drift_df = result[0]
|
216 |
+
# start_point = result[1]
|
217 |
+
# end_point = result[2]
|
218 |
+
# gr.Textbox(f"{start_point}〜{end_point}",label=f'データドリフト確認期間')
|
219 |
+
# with gr.Row():
|
220 |
+
# if len(drift_df["drift_score"].unique())!=1:
|
221 |
+
# gr.ScatterPlot(
|
222 |
+
# drift_df,
|
223 |
+
# x="feature_impact",
|
224 |
+
# y="drift_score",
|
225 |
+
# title="データドリフトとデータの有用性",
|
226 |
+
# color_legend_title="Species",
|
227 |
+
# x_title="特徴量の有用性",
|
228 |
+
# y_title="ドリフトスコア",
|
229 |
+
# x_lim = [-0.1, drift_df["feature_impact"].max()*1.4],
|
230 |
+
# y_lim = [-0.1, drift_df["drift_score"].max()*1.4],
|
231 |
+
# tooltip=["feature_name", "feature_impact", "drift_score"],
|
232 |
+
# caption="",
|
233 |
+
# height=500,
|
234 |
+
# width=500
|
235 |
+
# )
|
236 |
+
# else:
|
237 |
+
# gr.Markdown(
|
238 |
+
# """
|
239 |
+
# モデルの入れ替え後に予測が実行されていないためdriftは表示できません。
|
240 |
+
# """
|
241 |
+
# )
|
242 |
+
|
243 |
+
retrain_btn.click(retrain, inputs=None, outputs = [output_model_type, output_acc, output_plot])
|
244 |
+
|
245 |
+
electoric_ploting.load(lambda: datetime.datetime.now(),
|
246 |
+
None,
|
247 |
+
# c_time2,
|
248 |
+
every=3600)
|
249 |
+
dep = electoric_ploting.load(plot_prediction_result, None, plot, every=3600)
|
250 |
+
|
251 |
+
electoric_ploting.queue().launch()
|
252 |
+
|
253 |
+
plt.close()
|