NaokiOkamoto committed on
Commit
d18cc82
1 Parent(s): 81050a5

Upload train_modeling.py

Browse files
Files changed (1) hide show
  1. function/train_modeling.py +185 -0
function/train_modeling.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import gradio as gr
4
+ import datetime
5
+ from dateutil.relativedelta import relativedelta
6
+ from func import get_fish_qty, get_estat, dr_prediction_deployment
7
+
8
+ import yaml
9
# Load the project settings once at import time.  The encoding is pinned to
# UTF-8 so the file is decoded the same way on every platform instead of
# depending on the locale's default code page.
with open('config.yaml', encoding='utf-8') as file:
    # safe_load accepts the open stream directly; no need to read() first.
    config = yaml.safe_load(file)
11
+
12
def _to_yyyymm(values, date_format):
    """Parse the date strings in *values* and return them as YYYYMM integers."""
    return pd.to_datetime(values, format=date_format).dt.strftime('%Y%m').astype(int)


def create_train_data() -> pd.DataFrame:
    """Build the training table for the electricity-bill model.

    Steps, in order:
      1. Take the monthly '3.1 電気代' series from the household survey as the
         target column '電気代', keyed by '年月' (YYYYMM integer).
      2. Left-join fuel procurement cost columns, lagged by one row per '種別'.
      3. Left-join crude oil price columns, lagged by three rows.
      4. Refresh 'data/fish_sell_ach.csv' with newly fetched wholesale records
         (the file is overwritten), add 7/14/21/28-row lag columns, and
         inner-join the result on '年月'.

    Returns:
        The merged DataFrame. The original 'date' column of the fish data is
        returned as 'forecast_point'; 'target_date' is that date plus one month.
    """
    # --- Target: monthly electricity expenditure ---
    household_survey = get_estat.get_household_survey()
    expence_df = pd.DataFrame({'年月': household_survey['時間軸(月次)'].unique()})
    cate = '3.1 電気代'
    electricity_df = household_survey.loc[household_survey['品目分類(2020年改定)'] == cate]
    unit = electricity_df['unit'].unique()[0]
    value_col = f'{cate}_({unit})'
    electricity_df = electricity_df.rename(columns={'value': value_col, '時間軸(月次)': '年月'})
    expence_df = pd.merge(expence_df,
                          electricity_df[['年月', value_col]],
                          on='年月',
                          how='left')
    expence_df = expence_df.rename(columns={'3.1 電気代_(円)': '電気代'})
    expence_df['年月'] = _to_yyyymm(expence_df['年月'], '%Y年%m月')

    # --- Crude oil prices ---
    oil_price_df = pd.read_excel(config['oil_price_url'], header=5)
    oil_price_df = oil_price_df.rename(columns={oil_price_df.columns[0]: '年'})
    # Forward-fill blank year cells.  Series.ffill() replaces the deprecated
    # interpolate(method='ffill'), which is not reliable on string columns.
    oil_price_df['年'] = oil_price_df['年'].ffill()
    oil_price_df['年月'] = oil_price_df['年'] + oil_price_df['月'].astype(str) + '月'
    oil_price_df['年月'] = _to_yyyymm(oil_price_df['年月'], '%Y年%m月')

    # --- Fuel procurement costs ---
    fuel_df = pd.read_excel(config['fuel_procurement_cost_url'], header=4)
    fuel_df = fuel_df.iloc[:, 3:]
    # Strip line breaks embedded in the header cells (non-string headers are kept as-is).
    fuel_df = fuel_df.rename(
        columns=lambda name: name.replace('\n', '') if isinstance(name, str) else name)

    period_col = '燃料費調整単価適用期間'
    fuel_df[period_col] = fuel_df[period_col].ffill()
    # The period values themselves still contain a line break between year and month.
    fuel_df[period_col] = _to_yyyymm(fuel_df[period_col], '%Y年\n%m月')

    for kind in fuel_df['種別'].unique():
        kind_df = fuel_df.loc[fuel_df['種別'] == kind].drop('種別', axis=1)
        kind_df = kind_df.rename(columns={kind_df.columns[0]: '年月'})
        value_cols = [col for col in kind_df.columns if col != '年月']
        # Lag every value column by one row within this kind, then tag the name.
        kind_df[value_cols] = kind_df[value_cols].shift(1)
        kind_df = kind_df.rename(columns={col: f'{col}_{kind}_lag1' for col in value_cols})
        expence_df = pd.merge(expence_df,
                              kind_df,
                              on='年月',
                              how='left')

    # --- Join the lagged oil prices ---
    oil_cols = ['ブレント', 'ドバイ', 'WTI', 'OPECバスケット']
    oil_lag_cols = [f'{col}_lag3' for col in oil_cols]
    oil_price_df[oil_lag_cols] = oil_price_df[oil_cols].shift(3)
    expence_df = pd.merge(expence_df,
                          oil_price_df[oil_lag_cols + ['年月']],
                          on='年月',
                          how='left')

    # --- Fish wholesale data: fetch what is missing since the last saved record ---
    last_time_fish_arch = pd.read_csv('data/fish_sell_ach.csv')
    start_date = pd.to_datetime(str(last_time_fish_arch['date'].max()))
    # `today` was referenced without ever being defined; use the current local date.
    today = datetime.date.today()
    end_date = pd.to_datetime(today + relativedelta(days=1))
    use_fish_list = config['use_fish_list']
    temp_sell_ach = get_fish_qty.get_fish_price_data(start_date, end_date, use_fish_list)
    temp_sell_ach['date'] = temp_sell_ach['date'].astype(int)
    known_dates = last_time_fish_arch['date'].unique()
    # ignore_index gives the combined frame a clean, unique index so the
    # positional lag assignment below never has to align on duplicate labels.
    sell_ach = pd.concat([last_time_fish_arch,
                          temp_sell_ach.loc[~temp_sell_ach['date'].isin(known_dates)]],
                         ignore_index=True)
    sell_ach.to_csv('data/fish_sell_ach.csv', index=False)

    # --- Assemble the training rows ---
    # Each record is paired with the date one month after it.
    sell_ach['target_date'] = sell_ach['date'].apply(
        lambda x: int((pd.to_datetime(str(x)) + relativedelta(months=1)).strftime('%Y%m%d')))
    sell_ach['年月'] = sell_ach['target_date'].astype(str).str[:6].astype(int)

    col_list = ['するめいか_卸売数量計(kg)',
                'いわし_卸売数量計(kg)',
                'ぶり・わらさ_卸売数量計(kg)',
                '冷さけ_卸売数量計(kg)',
                '塩さけ_卸売数量計(kg)',
                'さけます類_卸売数量計(kg)',
                '全卸売数量計(kg)']

    # NOTE(review): the lags are row-based, so this assumes the CSV holds one
    # row per day in ascending date order - confirm.
    for shift_i in [7, 14, 21, 28]:
        change_col_list = [f'{col}_lag{shift_i}' for col in col_list]
        sell_ach[change_col_list] = sell_ach[col_list].shift(shift_i)

    sell_ach = sell_ach.rename(columns={'date': 'forecast_point'})
    train_df = pd.merge(expence_df,
                        sell_ach,
                        on='年月')

    return train_df
96
+
97
+
98
def modeling():
    """Create a DataRobot project from fresh training data and run Autopilot.

    The training table comes from ``create_train_data()``.  The project uses
    date/time partitioning on 'target_date' with one backtest: the final year
    of data is the holdout, the year before it is the backtest validation
    window, and everything earlier is the primary training window.

    The API token is read from the ``DATAROBOT_API_TOKEN`` environment
    variable; ``DATAROBOT_ENDPOINT`` optionally overrides the endpoint.

    Raises:
        RuntimeError: if ``DATAROBOT_API_TOKEN`` is not set.
    """
    # Function-scope imports: the module never imported these names, so the
    # original body failed with NameError on the first `dr.` reference.
    import os

    import datarobot as dr

    # --- DataRobot connection settings ---
    # Credentials must not be committed to source control; take them from the
    # environment and fail before doing any expensive data collection.
    token = os.environ.get('DATAROBOT_API_TOKEN')
    if not token:
        raise RuntimeError(
            'DATAROBOT_API_TOKEN is not set; export a DataRobot API token before calling modeling().')
    endpoint = os.environ.get('DATAROBOT_ENDPOINT', 'https://app.datarobot.com/api/v2')

    train_df = create_train_data()

    # --- Project name (date-stamped as YYYYMMDD) ---
    project_name = f'{datetime.datetime.now().strftime("%Y%m%d")}_ESTYLEU_電気代予測_再学習'

    # --- Feature settings ---
    target = '電気代'
    feature_timeline = 'target_date'  # date/time partition column
    not_use_feature = ['年月', 'forecast_point']
    # Gap between training and validation.
    # NOTE(review): 'P0Y' is meant to express a zero-length gap - confirm
    # DataRobot accepts this form.
    gap = 'P0Y'
    # Number of backtests.
    number_of_backtests = 1

    # --- Partition dates ---
    # These must come from the date column, not from the target values.
    holdout_end_date = pd.to_datetime(str(train_df[feature_timeline].max()))
    # `years=` shifts relatively; `year=` would replace the year with the literal 1.
    holdout_start_date = holdout_end_date - relativedelta(years=1)
    backtest_end_date = holdout_start_date - relativedelta(days=1)
    backtest_start_date = backtest_end_date - relativedelta(years=1)
    train_end_date = backtest_start_date - relativedelta(days=1)
    train_start_date = pd.to_datetime(str(train_df[feature_timeline].min()))
    # NOTE(review): adjacent windows are separated by one day while the gap is
    # declared as zero - verify this matches DataRobot's boundary semantics.

    # --- Modeling mode ---
    mode = dr.AUTOPILOT_MODE.FULL_AUTO
    dr.Client(
        endpoint=endpoint,
        token=token
    )

    # --- Backtest settings ---
    backtests_setting = [dr.BacktestSpecification(
        index=0,
        primary_training_start_date=train_start_date,
        primary_training_end_date=train_end_date,
        validation_start_date=backtest_start_date,
        validation_end_date=backtest_end_date
    )]

    spec = dr.DatetimePartitioningSpecification(
        feature_timeline,
        use_time_series=False,
        disable_holdout=False,
        holdout_start_date=holdout_start_date,
        holdout_end_date=holdout_end_date,
        gap_duration=gap,
        number_of_backtests=number_of_backtests,
        backtests=backtests_setting,
    )

    # Every training column except the ones explicitly excluded.
    use_feature_list = [col for col in train_df.columns if col not in not_use_feature]

    print('now creating project')
    project = dr.Project.create(
        train_df,
        project_name=project_name
    )

    # Pick up the features DataRobot derived from the date column; their
    # names start with the date column name followed by a space.
    raw = [feat_list for feat_list in project.get_featurelists()
           if feat_list.name == 'Informative Features'][0]
    raw_features = [feat for feat in raw.features if f'{feature_timeline} ' in feat]

    # list.extend() mutates in place and returns None, so its result must not
    # be assigned back to the list.
    use_feature_list.extend(raw_features)

    featurelist = project.create_featurelist('モデリング', use_feature_list)

    print("start modeling")
    project.analyze_and_model(
        target=target,
        mode=mode,
        partitioning_method=spec,
        max_wait=3000,
        worker_count=-1,
        featurelist_id=featurelist.id
    )
    print(project.get_leaderboard_ui_permalink())
    project.wait_for_autopilot()
184
+
185
+