guyar commited on
Commit
665af2f
1 Parent(s): 654e9ae

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -0
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ import shap
7
+ import lightgbm as lgb
8
+ import yaml
9
+ import numpy as np
10
+ import sys
11
+ import os
12
+
13
+ current_path = os.getcwd() + '\\' + __file__
14
+ vis_path = os.path.dirname(current_path) + '\\visualisations'
15
+ sys.path.append(vis_path)
16
+ import terravisualisation as tmvis
17
+
18
+ matplotlib.use("Agg")
19
+
20
+ bontiledict = {'SPD + 2C':'BON1',
21
+ 'cult + 4C':'BON2',
22
+ '+6C':'BON3',
23
+ '+3pw 1 ship':'BON4',
24
+ '+1W + 3PW':'BON5',
25
+ 'pass-vp:SA/SH*4 + 2W':'BON6',
26
+ 'pass-vp:TP*2 + 1W':'BON7',
27
+ '+1P':'BON8',
28
+ 'pass-vp:D*1 + 2C':'BON9',
29
+ 'pass-vp: ship*3 + 3pw':'BON10'
30
+ }
31
+
32
+ bontiledict_reverse = {'BON1': 'SPD + 2C',
33
+ 'BON2': 'cult + 4C',
34
+ 'BON3': '+6C',
35
+ 'BON4': '+3pw 1 ship',
36
+ 'BON5': '+1W + 3PW',
37
+ 'BON6': 'pass-vp:SA/SH*4 + 2W',
38
+ 'BON7': 'pass-vp:TP*2 + 1W',
39
+ 'BON8': '+1P',
40
+ 'BON9': 'pass-vp:D*1 + 2C',
41
+ 'BON10': 'pass-vp: ship*3 + 3pw'
42
+ }
43
+
44
+ round_tiles_dict = {'SPADE >> 2':'SCORE1',
45
+ 'TOWN >> 5':'SCORE2',
46
+ 'D >> 2':'SCORE3',
47
+ 'SA/SH >> 5':'SCORE4',
48
+ 'D >> 2':'SCORE5',
49
+ 'TP >> 3':'SCORE6',
50
+ 'SA/SH >> 5':'SCORE7',
51
+ 'TP >> 3':'SCORE8',
52
+ 'TE >> 4':'SCORE9'}
53
+
54
+ round_tiles_dict_reverse = {'SCORE1': 'SPADE >> 2',
55
+ 'SCORE2': 'TOWN >> 5',
56
+ 'SCORE3': 'D >> 2',
57
+ 'SCORE4': 'SA/SH >> 5',
58
+ 'SCORE5': 'D >> 2',
59
+ 'SCORE6': 'TP >> 3',
60
+ 'SCORE7': 'SA/SH >> 5',
61
+ 'SCORE8': 'TP >> 3',
62
+ 'SCORE9': 'TE >> 4'}
63
+
64
+
65
+ round_tiles = list(round_tiles_dict.keys())
66
+ bontiles = list(bontiledict.keys())
67
+
68
+ factions = ['Witches', 'Auren', 'Giants', 'Chaos Magicians', 'Darklings', 'Alchemists',
69
+ 'Swarmlings', 'Mermaids', 'Fakirs', 'Nomads', 'Engineers', 'Dwarves', 'Halflings', 'Cultists']
70
+
71
+ players = ['2players', '3players', '4players', '5players']
72
+
73
+ maps = ['map1', 'map2', 'map3']
74
+
75
+ faction_cols = ['Yellow', 'Red', 'Grey', 'Black', 'Blue', 'Green', 'Brown']
76
+
77
+ with open('params.yaml', 'r') as fd:
78
+ params = yaml.safe_load(fd)
79
+
80
+ vpdfdir = params['prepare']['vp-data-dir']
81
+ featdfdir = params['prepare']['feature-data-dir']
82
+ pickledir = params['prepare-step2']['pickle-dir']
83
+
84
+ feature_columns = ['x0_SCORE1', 'x0_SCORE2', 'x0_SCORE3', 'x0_SCORE4', 'x0_SCORE5',
85
+ 'x0_SCORE6', 'x0_SCORE7', 'x0_SCORE8', 'x0_SCORE9', 'x1_SCORE1',
86
+ 'x1_SCORE2', 'x1_SCORE3', 'x1_SCORE4', 'x1_SCORE5', 'x1_SCORE6',
87
+ 'x1_SCORE7', 'x1_SCORE8', 'x1_SCORE9', 'x2_SCORE1', 'x2_SCORE2',
88
+ 'x2_SCORE3', 'x2_SCORE4', 'x2_SCORE5', 'x2_SCORE6', 'x2_SCORE7',
89
+ 'x2_SCORE8', 'x2_SCORE9', 'x3_SCORE1', 'x3_SCORE2', 'x3_SCORE3',
90
+ 'x3_SCORE4', 'x3_SCORE5', 'x3_SCORE6', 'x3_SCORE7', 'x3_SCORE8',
91
+ 'x3_SCORE9', 'x4_SCORE1', 'x4_SCORE2', 'x4_SCORE3', 'x4_SCORE4',
92
+ 'x4_SCORE5', 'x4_SCORE6', 'x4_SCORE7', 'x4_SCORE8', 'x4_SCORE9',
93
+ 'x5_SCORE2', 'x5_SCORE3', 'x5_SCORE4', 'x5_SCORE5', 'x5_SCORE6',
94
+ 'x5_SCORE7', 'x5_SCORE8', 'x5_SCORE9', 'BON1', 'BON2', 'BON3', 'BON4',
95
+ 'BON5', 'BON6', 'BON7', 'BON8', 'BON9', 'BON10', 'no_players', 'red',
96
+ 'blue', 'green', 'black', 'grey', 'yellow', 'brown', 'x0_map1',
97
+ 'x0_map2', 'x0_map3']
98
+
99
+
100
+
101
+ def args_to_features(*args):
102
+ # round1, round2, round3, round4, round5, round6, faction, map, playerschosen, bon_tiles, fac_cols = args
103
+ Xdata = pd.DataFrame(data=np.zeros((1, len(feature_columns))), columns=feature_columns)
104
+
105
+ for arg_no, user_input in enumerate(args):
106
+ if arg_no in range(6): # if it's a round
107
+ # map back to col name
108
+ feat_label_name = f'x{arg_no}_{round_tiles_dict[user_input]}'
109
+ Xdata[feat_label_name].iloc[0] = 1
110
+ elif arg_no == 6:
111
+ faction = user_input
112
+ if faction == 'Chaos Magicians':
113
+ faction = 'chaosmagicians'
114
+ elif arg_no == 7: # map
115
+ feat_label_name = f'x0_{user_input}'
116
+ Xdata[feat_label_name].iloc[0] = 1
117
+ elif arg_no == 8: # playerschosen
118
+ Xdata['no_players'].iloc[0] = int(user_input[0])
119
+ elif arg_no == 9: # bon_tiles
120
+ for bon_tile in user_input:
121
+ Xdata[bontiledict[bon_tile]].iloc[0] = 1
122
+ elif arg_no == 9: # fac_cols
123
+ for fac_col in user_input:
124
+ Xdata[fac_col.lower()].iloc[0] = 1
125
+
126
+ return Xdata, faction
127
+
128
+ def display_map(faction, map):
129
+ map_fig = plt.figure(tight_layout=True)
130
+
131
+ x, y = tmvis.display_map(faction, plot=False)
132
+ a = map_fig.add_subplot(111)
133
+ a.hexbin(x, y, gridsize=(19, 9), cmap='magma')
134
+ a.axis('off')
135
+ return map_fig
136
+
137
+
138
+ def predict(*args):
139
+ Xdata, faction = args_to_features(*args)
140
+
141
+ modelfile = f'D://PycharmProjects/TerraBot/data/faction-picker-bot/models/{faction}_model.txt'
142
+ bst = lgb.Booster(model_file=modelfile)
143
+
144
+ return f'Final score: {round(bst.predict(Xdata)[0])}'
145
+
146
+
147
+ def interpret(*args):
148
+ Xdata, faction = args_to_features(*args)
149
+ modelfile = f'D://PycharmProjects/TerraBot/data/faction-picker-bot/models/{faction}_model.txt'
150
+ bst = lgb.Booster(model_file=modelfile)
151
+ bst.params["objective"] = "regression"
152
+ explainer = shap.Explainer(bst)
153
+
154
+ copycols = []
155
+ for ii, column in enumerate(Xdata.columns):
156
+ if column[-6:] in round_tiles_dict_reverse.keys():
157
+ copycols.append(column[:3] + round_tiles_dict_reverse[column[-6:]])
158
+ elif column in bontiledict_reverse.keys():
159
+ copycols.append(bontiledict_reverse[column])
160
+ else:
161
+ copycols.append(column)
162
+
163
+ Xdata.columns = copycols
164
+
165
+ shap_values = explainer(Xdata)
166
+ fig_m = plt.figure(tight_layout=True, facecolor=(0.125,0.172,0.203))
167
+ ax = plt.gca()
168
+ ax.set_facecolor((0.125,0.172,0.203))
169
+ matplotlib.rcParams['axes.labelcolor'] = 'w'
170
+ shap.plots.waterfall(shap_values[0])
171
+ # shap.initjs()
172
+ # shap.plots.force(shap_values[0])
173
+ return fig_m
174
+
175
+
176
+
177
+ with gr.Blocks() as demo:
178
+ gr.Markdown("""
179
+ **Predict final faction score given the initial board setup 💰**: This model uses an lightgbm regression to make prediction.
180
+ The [source code for this work is here](https://github.com/guyreading/terrabot/faction-picker-bot/gradio_interface.py).
181
+ """)
182
+ with gr.Row():
183
+ with gr.Column():
184
+ faction = gr.Dropdown(
185
+ label="Faction",
186
+ choices=factions,
187
+ value=lambda: random.choice(factions),
188
+ )
189
+
190
+ round1_tile = gr.Dropdown(
191
+ label="Round 1 tile",
192
+ choices=round_tiles,
193
+ value=lambda: random.choice(round_tiles),
194
+ )
195
+
196
+ round2_tile = gr.Dropdown(
197
+ label="Round 2 tile",
198
+ choices=round_tiles,
199
+ value=lambda: random.choice(round_tiles),
200
+ )
201
+
202
+ round3_tile = gr.Dropdown(
203
+ label="Round 3 tile",
204
+ choices=round_tiles,
205
+ value=lambda: random.choice(round_tiles),
206
+ )
207
+
208
+ round4_tile = gr.Dropdown(
209
+ label="Round 4 tile",
210
+ choices=round_tiles,
211
+ value=lambda: random.choice(round_tiles),
212
+ )
213
+
214
+ round5_tile = gr.Dropdown(
215
+ label="Round 5 tile",
216
+ choices=round_tiles,
217
+ value=lambda: random.choice(round_tiles),
218
+ )
219
+
220
+ round6_tile = gr.Dropdown(
221
+ label="Round 6 tile",
222
+ choices=round_tiles,
223
+ value=lambda: random.choice(round_tiles),
224
+ )
225
+
226
+ bon_tiles_gr = gr.CheckboxGroup(label='Bonus tiles present', choices=list(bontiledict.keys()))
227
+
228
+ map = gr.Dropdown(
229
+ label="Map",
230
+ choices=maps,
231
+ value=lambda: random.choice(maps),
232
+ )
233
+
234
+ playerschosen = gr.Dropdown(
235
+ label="No. Of Players",
236
+ choices=players,
237
+ value=lambda: random.choice(players),
238
+ )
239
+
240
+ fac_cols_gr = gr.CheckboxGroup(label='Other faction colours present', choices=faction_cols)
241
+
242
+
243
+ with gr.Column():
244
+ map_plot = gr.Plot(label='Distance from home terrain: darker is further')
245
+
246
+ with gr.Row():
247
+ predict_btn = gr.Button(value="Predict")
248
+ interpret_btn = gr.Button(value="Explain")
249
+
250
+ label = gr.Label(label=f'Prediction of final VP for faction:')
251
+ plot = gr.Plot(label=f'Breakdown of prediction for faction:')
252
+
253
+ predict_btn.click(
254
+ predict,
255
+ inputs=[
256
+ round1_tile,
257
+ round2_tile,
258
+ round3_tile,
259
+ round4_tile,
260
+ round5_tile,
261
+ round6_tile,
262
+ faction,
263
+ map,
264
+ playerschosen,
265
+ bon_tiles_gr,
266
+ fac_cols_gr
267
+ ],
268
+ outputs=[label],
269
+ )
270
+ interpret_btn.click(
271
+ interpret,
272
+ inputs=[
273
+ round1_tile,
274
+ round2_tile,
275
+ round3_tile,
276
+ round4_tile,
277
+ round5_tile,
278
+ round6_tile,
279
+ faction,
280
+ map,
281
+ playerschosen,
282
+ bon_tiles_gr,
283
+ fac_cols_gr
284
+ ],
285
+ outputs=[plot],
286
+ )
287
+
288
+ faction.change(
289
+ display_map,
290
+ inputs=[
291
+ faction,
292
+ map,
293
+ ],
294
+ outputs=[map_plot],
295
+ )
296
+
297
+ map.change(
298
+ display_map,
299
+ inputs=[
300
+ faction,
301
+ map,
302
+ ],
303
+ outputs=[map_plot],
304
+ )
305
+
306
+ demo.launch()