NaokiOkamoto commited on
Commit
ec74bc1
1 Parent(s): 9ddce62

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import gradio as gr
4
+ import datetime
5
+ import calendar
6
+ import matplotlib.pyplot as plt
7
+ import japanize_matplotlib
8
+ import matplotlib.dates as mdates
9
+ from dateutil.relativedelta import relativedelta
10
+ import datetime
11
+ import datarobot as dr
12
+ from function import get_fish_qty, get_estat, dr_prediction_deployment, prediction_func, train_modeling
13
+
14
+ import yaml
15
# Load application settings from config.yaml in the working directory.
# The file is opened explicitly as UTF-8 so parsing does not depend on the
# platform's default locale encoding; safe_load reads directly from the stream.
with open('config.yaml', encoding='utf-8') as file:
    config = yaml.safe_load(file)
17
+
18
def retrain():
    """Retrain via ``train_modeling.modeling()`` and summarise the resulting model.

    The first row of the table returned by ``train_modeling.modeling()`` is
    used: its ``model_url`` is split on ``'/'`` to obtain the DataRobot project
    id (segment index 4) and model id (last segment), and its ``model_type``
    is reported back to the UI.

    Returns:
        tuple: ``(model_type, holdout_rmse, feature_impact)`` where
        ``feature_impact`` is a DataFrame holding at most the 20 rows with the
        highest ``impactNormalized``; each ``featureName`` is prefixed with a
        zero-padded rank (``"01_"``, ``"02_"``, ...).
    """
    model_management_df = train_modeling.modeling()
    first_row = model_management_df.iloc[0, :]

    url_parts = first_row['model_url'].split('/')
    model = dr.Model.get(project=dr.Project.get(url_parts[4]),
                         model_id=url_parts[-1])

    feature_impact = pd.DataFrame(model.get_or_request_feature_impact())
    feature_impact = feature_impact.sort_values('impactNormalized', ascending=False).reset_index(drop=True)
    # Take an explicit copy of the top-20 slice so the column write below
    # targets a real frame, not a view (chained indexing such as
    # frame['col'][i] = ... is unreliable and is a no-op under copy-on-write).
    feature_impact = feature_impact.iloc[:20, :].copy()
    feature_impact['featureName'] = [
        f'{rank:02d}_{name}'
        for rank, name in enumerate(feature_impact['featureName'], start=1)
    ]

    return first_row['model_type'], model.metrics['RMSE']['holdout'], feature_impact
30
+
31
+
32
def get_prediction_result():
    """Return next month's daily predictions as a DataFrame.

    Builds one row per calendar day of the month following "now", left-joins
    the rows of ``prediction_func.prediction_to_dr(...)`` whose ``target_date``
    falls in that month, and fills a missing ``forecast_point`` with
    ``target_date`` minus one month. Rows that have no ``電気代`` value and
    whose ``target_date`` is earlier than "now + 1 month" are dropped.

    Returns:
        pandas.DataFrame: columns ``forecast_point``, ``target_date`` and
        ``電気代_予測`` (the ``電気代`` column renamed).
    """
    today = datetime.datetime.now()
    next_month = today + relativedelta(months=1)
    prediction_month = next_month.strftime('%Y%m')

    # One timestamp per day of the prediction month.
    days_in_month = calendar.monthrange(next_month.year, next_month.month)[1]
    month_days = [
        pd.to_datetime(prediction_month + str(day).zfill(2))
        for day in range(1, days_in_month + 1)
    ]
    dfc = pd.DataFrame({'target_date': month_days})

    df = prediction_func.prediction_to_dr(config['oil_price_url'], config['fuel_procurement_cost_url'])
    # Copy the filtered rows so the column assignments below do not write
    # into a slice of the frame returned by prediction_to_dr.
    df = df.loc[df['target_date'].astype(str).str[:6] == prediction_month].copy()
    df['target_date'] = pd.to_datetime(df['target_date'].astype(str))
    df['forecast_point'] = pd.to_datetime(df['forecast_point'].astype(str))
    df = pd.merge(dfc,
                  df,
                  on='target_date',
                  how='left')

    # Days without a prediction row get a forecast point one month earlier.
    df.loc[df['forecast_point'].isnull(), 'forecast_point'] = df['target_date'].apply(lambda x: x - relativedelta(months=1))
    df = df.loc[~((df['target_date'] < next_month) & (df['電気代'].isnull()))]
    df = df.rename(columns={'電気代': '電気代_予測'})
    return df[['forecast_point', 'target_date', '電気代_予測']]
49
+
50
def plot_prediction_result():
    """Return a ``gr.LinePlot.update`` payload with the latest predictions.

    The plotted value comes from ``get_prediction_result()``, with
    ``target_date`` on the x axis and ``電気代_予測`` on the y axis.
    """
    latest_predictions = get_prediction_result()
    plot_options = {
        'x': "target_date",
        'y': "電気代_予測",
        'title': "昨日までの魚の卸売り量から予測された、来月の2人世帯の平均電気料金の推移",
        'width': 500,
        'height': 300,
    }
    return gr.LinePlot.update(value=latest_predictions, **plot_options)
60
+
61
def get_model_infomation():
    """Fetch the best model of the '電気代予測' DataRobot project.

    Credentials are read from the environment instead of being embedded in
    source: ``DATAROBOT_API_TOKEN`` is required, and ``DATAROBOT_ENDPOINT`` is
    optional (defaults to ``https://app.datarobot.com/api/v2``).

    The first project whose string representation contains ``'電気代予測'`` is
    used. Its datetime models (excluding the
    'Baseline Predictions Using Most Recent Value' model type) are ranked by
    holdout RMSE in ascending order and the first one is selected.

    Returns:
        tuple: ``(model_info, feature_impact)`` where ``model_info`` is a dict
        with keys ``'RMSE'`` (holdout RMSE) and ``'model_type'``, and
        ``feature_impact`` is a DataFrame with at most the 20 rows having the
        highest ``impactNormalized``.

    Raises:
        RuntimeError: if ``DATAROBOT_API_TOKEN`` is not set.
    """
    import os  # local import keeps this block self-contained

    # SECURITY: an API token was previously hard-coded here. Any token that
    # was committed to the repository should be revoked and replaced.
    token = os.environ.get('DATAROBOT_API_TOKEN')
    if not token:
        raise RuntimeError(
            'DATAROBOT_API_TOKEN is not set; configure it as an environment '
            'variable or secret instead of hard-coding the token.'
        )
    endpoint = os.environ.get('DATAROBOT_ENDPOINT', 'https://app.datarobot.com/api/v2')
    dr.Client(
        endpoint=endpoint,
        token=token
    )
    project = dr.Project.get([i for i in dr.Project.list() if '電気代予測' in str(i)][0].id)

    model_df = pd.DataFrame(
        [[model.id,
          model.model_type,
          model.metrics['RMSE']['validation'],
          model.metrics['RMSE']['backtesting'],
          model.metrics['RMSE']['holdout'],
          model] for model in project.get_datetime_models() if model.model_type != 'Baseline Predictions Using Most Recent Value'],
        columns=['ID', 'モデル名', 'バックテスト1', '全てのバックテスト', 'holdout', 'model'])
    model_df = model_df.sort_values('holdout').reset_index(drop=True)

    # After reset_index the first row is the model with the lowest holdout RMSE.
    model = model_df['model'].iloc[0]
    model_info = {
        'RMSE': model.metrics['RMSE']['holdout'],
        'model_type': model.model_type,
    }

    feature_impact = pd.DataFrame(model.get_or_request_feature_impact())
    feature_impact = feature_impact.sort_values('impactNormalized', ascending=False).reset_index(drop=True)
    # Return an independent copy so callers can modify it without writing
    # into a slice of the full feature-impact table.
    feature_impact = feature_impact.iloc[:20, :].copy()

    return model_info, feature_impact
92
+
93
+ # def get_featuredrift():
94
+ # deployment = dr.Deployment.get(deployment_id='640d791796a6a52d92c368a0')
95
+ # target_drift = dr.models.TargetDrift.get(deployment.id)
96
+ # feature_drift_list = dr.models.FeatureDrift.list(deployment.id)
97
+ # drift_df = pd.DataFrame(
98
+ # {
99
+ # 'feature_name':[target_drift.target_name],
100
+ # 'drift_score':[target_drift.drift_score],
101
+ # 'feature_impact':[1]
102
+ # }
103
+ # )
104
+ # drift_df = pd.concat([
105
+ # drift_df,
106
+ # pd.DataFrame(
107
+ # [[
108
+ # feature_drift.name,
109
+ # feature_drift.drift_score,
110
+ # feature_drift.feature_impact
111
+ # ] for feature_drift in feature_drift_list
112
+ # ],
113
+ # columns=[ 'feature_name', 'drift_score', 'feature_impact']
114
+ # )
115
+ # ])
116
+ # start_point = (target_drift.period['start']+relativedelta(hours=9)).strftime("%Y / %m / %d %H:%M:%S")
117
+ # end_point = (target_drift.period['end']+relativedelta(hours=9)).strftime("%Y / %m / %d %H:%M:%S")
118
+
119
+ # return drift_df, start_point, end_point
120
+
121
# Dashboard definition: a "prediction result" tab (line plot, summary
# statistics, raw table) and a "model information" tab (current model vs.
# the model produced by pressing the retrain button).
with gr.Blocks() as electoric_ploting:
    gr.Markdown(
        """
        # その日の魚の卸売り量から、来月の家計データ月別支出の電気代を予測するAI
        使用データ
        * 東京卸売市場日報
        * 家計調査の月別支出
        * 原油価格データ
        * 燃料調達価格データ
        why
        電気代のtrendは原油価格などが大きく影響するが、細かい変化は気候に影響し、気候はある程度海水温に関連性があると考えられる。
        また、魚の卸売量は水揚げ量に関係し、水揚げ量は海水温に関係するという考えからモデルを作成。
        """
    )
    with gr.Tab("予測結果"):
        with gr.Row():
            with gr.Column():
                # Filled in by the load event registered at the bottom.
                plot = gr.LinePlot(show_label=False)
                # plot = gr.Plot(label="昨日までの魚の卸売り量から予測された、来月の2人世帯の平均電気料金の推移")
            with gr.Column():
                # Summary statistics are computed once, when the UI is built.
                df = get_prediction_result()
                gr.Textbox(df['電気代_予測'].max(),
                           label='現在までの予測値の最大値')
                gr.Textbox(df['電気代_予測'].min(),
                           label='現在までの予測値の最小値')
                gr.Textbox(df['電気代_予測'].mean(),
                           label='現在までの予測値の平均値')
                gr.Textbox(df['電気代_予測'].median(),
                           label='現在までの予測値の中央値')
        with gr.Row():
            # The function itself (not its result) is passed as the value.
            gr.DataFrame(get_prediction_result)

    with gr.Tab("モデル情報"):
        gr.Markdown(
            """
            注意:
            再学習後はモデルのデプロイが自動で行われます。
            huggingfaceの使用上csvを上書きできないため。
            """
        )
        retrain_btn = gr.Button(value="再学習")
        with gr.Row():
            with gr.Column():
                model_info, feature_impact_df = get_model_infomation()
                gr.Textbox(model_info['model_type'], label='現在のモデル')

            with gr.Column():
                output_model_type = gr.Textbox(label='再学習後のモデル')

        with gr.Row():
            with gr.Column():
                gr.Textbox(model_info['RMSE'], label='Holdout RMSE精度')
            with gr.Column():
                output_acc = gr.Textbox(label='再学習後のHoldout RMSE精度')

        with gr.Row():
            with gr.Column():
                # Prefix each feature name with its zero-padded rank. The
                # column is assigned in one step on an explicit copy, because
                # element-wise chained indexing (frame['col'][i] = ...) on a
                # sliced frame is unreliable and a no-op under copy-on-write.
                feature_impact_df = feature_impact_df.copy()
                feature_impact_df['featureName'] = [
                    f'{rank:02d}_{name}'
                    for rank, name in enumerate(feature_impact_df['featureName'], start=1)
                ]
                gr.BarPlot(value=feature_impact_df,
                           title='特徴量インパクト上位20',
                           x='featureName',
                           y='impactNormalized',
                           tooltip=['impactNormalized'],
                           x_title='特徴量名',
                           y_title='特徴量インパクト_相対値',
                           vertical=False,
                           y_lim=[0, 1.2],
                           width=400,
                           height=300)
            with gr.Column():
                # Populated by the retrain button's click handler.
                output_plot = gr.BarPlot(title='再学習後特徴量インパクト上位20',
                                         x='featureName',
                                         y='impactNormalized',
                                         tooltip=['impactNormalized'],
                                         x_title='特徴量名',
                                         y_title='特徴量インパクト_相対値',
                                         vertical=False,
                                         y_lim=[0, 1.2],
                                         width=400,
                                         height=300)
    # Disabled data-drift tab, kept verbatim for reference.
    # with gr.Tab("データドリフト情報"):
    #     result = get_featuredrift()
    #     with gr.Row():
    #         gr.Markdown(
    #             """
    #             こちらの図はデータドリフトと特徴量の有用性を表した図になっています。
    #             味方は以下の通り
    #             * ドリフトスコア:予測データに含まれるデータが、どれぐらい過去のデータに比べてずれが発生しているかを表しており、上に行けば行くほどズレが大きい
    #             * 特徴量の有用性:ターゲットの有用性を1とした時に、どれぐらいそれぞれの特徴量の有用性が高いかを表したもので、右に行くほど有用性が高い
    #             """
    #         )
    #     with gr.Row():
    #         drift_df = result[0]
    #         start_point = result[1]
    #         end_point = result[2]
    #         gr.Textbox(f"{start_point}〜{end_point}",label=f'データドリフト確認期間')
    #     with gr.Row():
    #         if len(drift_df["drift_score"].unique())!=1:
    #             gr.ScatterPlot(
    #                 drift_df,
    #                 x="feature_impact",
    #                 y="drift_score",
    #                 title="データドリフトとデータの有用性",
    #                 color_legend_title="Species",
    #                 x_title="特徴量の有用性",
    #                 y_title="ドリフトスコア",
    #                 x_lim = [-0.1, drift_df["feature_impact"].max()*1.4],
    #                 y_lim = [-0.1, drift_df["drift_score"].max()*1.4],
    #                 tooltip=["feature_name", "feature_impact", "drift_score"],
    #                 caption="",
    #                 height=500,
    #                 width=500
    #             )
    #         else:
    #             gr.Markdown(
    #                 """
    #                 モデルの入れ替え後に予測が実行されていないためdriftは表示できません。
    #                 """
    #             )

    # Pressing the retrain button fills the three "after retraining" widgets
    # with the values returned by retrain().
    retrain_btn.click(retrain, inputs=None, outputs=[output_model_type, output_acc, output_plot])

    # Both load events are registered with every=3600 (presumably seconds,
    # per Gradio's `every` parameter; confirm against the installed version).
    electoric_ploting.load(lambda: datetime.datetime.now(),
                           None,
                           # c_time2,
                           every=3600)
    dep = electoric_ploting.load(plot_prediction_result, None, plot, every=3600)

electoric_ploting.queue().launch()

plt.close()