Spaces:
Running
Running
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import pandas as pd
|
6 |
+
import shap
|
7 |
+
import lightgbm as lgb
|
8 |
+
import yaml
|
9 |
+
import numpy as np
|
10 |
+
import sys
|
11 |
+
import os
|
12 |
+
|
13 |
+
current_path = os.getcwd() + '\\' + __file__
|
14 |
+
vis_path = os.path.dirname(current_path) + '\\visualisations'
|
15 |
+
sys.path.append(vis_path)
|
16 |
+
import terravisualisation as tmvis
|
17 |
+
|
18 |
+
matplotlib.use("Agg")
|
19 |
+
|
20 |
+
bontiledict = {'SPD + 2C':'BON1',
|
21 |
+
'cult + 4C':'BON2',
|
22 |
+
'+6C':'BON3',
|
23 |
+
'+3pw 1 ship':'BON4',
|
24 |
+
'+1W + 3PW':'BON5',
|
25 |
+
'pass-vp:SA/SH*4 + 2W':'BON6',
|
26 |
+
'pass-vp:TP*2 + 1W':'BON7',
|
27 |
+
'+1P':'BON8',
|
28 |
+
'pass-vp:D*1 + 2C':'BON9',
|
29 |
+
'pass-vp: ship*3 + 3pw':'BON10'
|
30 |
+
}
|
31 |
+
|
32 |
+
bontiledict_reverse = {'BON1': 'SPD + 2C',
|
33 |
+
'BON2': 'cult + 4C',
|
34 |
+
'BON3': '+6C',
|
35 |
+
'BON4': '+3pw 1 ship',
|
36 |
+
'BON5': '+1W + 3PW',
|
37 |
+
'BON6': 'pass-vp:SA/SH*4 + 2W',
|
38 |
+
'BON7': 'pass-vp:TP*2 + 1W',
|
39 |
+
'BON8': '+1P',
|
40 |
+
'BON9': 'pass-vp:D*1 + 2C',
|
41 |
+
'BON10': 'pass-vp: ship*3 + 3pw'
|
42 |
+
}
|
43 |
+
|
44 |
+
round_tiles_dict = {'SPADE >> 2':'SCORE1',
|
45 |
+
'TOWN >> 5':'SCORE2',
|
46 |
+
'D >> 2':'SCORE3',
|
47 |
+
'SA/SH >> 5':'SCORE4',
|
48 |
+
'D >> 2':'SCORE5',
|
49 |
+
'TP >> 3':'SCORE6',
|
50 |
+
'SA/SH >> 5':'SCORE7',
|
51 |
+
'TP >> 3':'SCORE8',
|
52 |
+
'TE >> 4':'SCORE9'}
|
53 |
+
|
54 |
+
round_tiles_dict_reverse = {'SCORE1': 'SPADE >> 2',
|
55 |
+
'SCORE2': 'TOWN >> 5',
|
56 |
+
'SCORE3': 'D >> 2',
|
57 |
+
'SCORE4': 'SA/SH >> 5',
|
58 |
+
'SCORE5': 'D >> 2',
|
59 |
+
'SCORE6': 'TP >> 3',
|
60 |
+
'SCORE7': 'SA/SH >> 5',
|
61 |
+
'SCORE8': 'TP >> 3',
|
62 |
+
'SCORE9': 'TE >> 4'}
|
63 |
+
|
64 |
+
|
65 |
+
round_tiles = list(round_tiles_dict.keys())
|
66 |
+
bontiles = list(bontiledict.keys())
|
67 |
+
|
68 |
+
factions = ['Witches', 'Auren', 'Giants', 'Chaos Magicians', 'Darklings', 'Alchemists',
|
69 |
+
'Swarmlings', 'Mermaids', 'Fakirs', 'Nomads', 'Engineers', 'Dwarves', 'Halflings', 'Cultists']
|
70 |
+
|
71 |
+
players = ['2players', '3players', '4players', '5players']
|
72 |
+
|
73 |
+
maps = ['map1', 'map2', 'map3']
|
74 |
+
|
75 |
+
faction_cols = ['Yellow', 'Red', 'Grey', 'Black', 'Blue', 'Green', 'Brown']
|
76 |
+
|
77 |
+
with open('params.yaml', 'r') as fd:
|
78 |
+
params = yaml.safe_load(fd)
|
79 |
+
|
80 |
+
vpdfdir = params['prepare']['vp-data-dir']
|
81 |
+
featdfdir = params['prepare']['feature-data-dir']
|
82 |
+
pickledir = params['prepare-step2']['pickle-dir']
|
83 |
+
|
84 |
+
feature_columns = ['x0_SCORE1', 'x0_SCORE2', 'x0_SCORE3', 'x0_SCORE4', 'x0_SCORE5',
|
85 |
+
'x0_SCORE6', 'x0_SCORE7', 'x0_SCORE8', 'x0_SCORE9', 'x1_SCORE1',
|
86 |
+
'x1_SCORE2', 'x1_SCORE3', 'x1_SCORE4', 'x1_SCORE5', 'x1_SCORE6',
|
87 |
+
'x1_SCORE7', 'x1_SCORE8', 'x1_SCORE9', 'x2_SCORE1', 'x2_SCORE2',
|
88 |
+
'x2_SCORE3', 'x2_SCORE4', 'x2_SCORE5', 'x2_SCORE6', 'x2_SCORE7',
|
89 |
+
'x2_SCORE8', 'x2_SCORE9', 'x3_SCORE1', 'x3_SCORE2', 'x3_SCORE3',
|
90 |
+
'x3_SCORE4', 'x3_SCORE5', 'x3_SCORE6', 'x3_SCORE7', 'x3_SCORE8',
|
91 |
+
'x3_SCORE9', 'x4_SCORE1', 'x4_SCORE2', 'x4_SCORE3', 'x4_SCORE4',
|
92 |
+
'x4_SCORE5', 'x4_SCORE6', 'x4_SCORE7', 'x4_SCORE8', 'x4_SCORE9',
|
93 |
+
'x5_SCORE2', 'x5_SCORE3', 'x5_SCORE4', 'x5_SCORE5', 'x5_SCORE6',
|
94 |
+
'x5_SCORE7', 'x5_SCORE8', 'x5_SCORE9', 'BON1', 'BON2', 'BON3', 'BON4',
|
95 |
+
'BON5', 'BON6', 'BON7', 'BON8', 'BON9', 'BON10', 'no_players', 'red',
|
96 |
+
'blue', 'green', 'black', 'grey', 'yellow', 'brown', 'x0_map1',
|
97 |
+
'x0_map2', 'x0_map3']
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
def args_to_features(*args):
|
102 |
+
# round1, round2, round3, round4, round5, round6, faction, map, playerschosen, bon_tiles, fac_cols = args
|
103 |
+
Xdata = pd.DataFrame(data=np.zeros((1, len(feature_columns))), columns=feature_columns)
|
104 |
+
|
105 |
+
for arg_no, user_input in enumerate(args):
|
106 |
+
if arg_no in range(6): # if it's a round
|
107 |
+
# map back to col name
|
108 |
+
feat_label_name = f'x{arg_no}_{round_tiles_dict[user_input]}'
|
109 |
+
Xdata[feat_label_name].iloc[0] = 1
|
110 |
+
elif arg_no == 6:
|
111 |
+
faction = user_input
|
112 |
+
if faction == 'Chaos Magicians':
|
113 |
+
faction = 'chaosmagicians'
|
114 |
+
elif arg_no == 7: # map
|
115 |
+
feat_label_name = f'x0_{user_input}'
|
116 |
+
Xdata[feat_label_name].iloc[0] = 1
|
117 |
+
elif arg_no == 8: # playerschosen
|
118 |
+
Xdata['no_players'].iloc[0] = int(user_input[0])
|
119 |
+
elif arg_no == 9: # bon_tiles
|
120 |
+
for bon_tile in user_input:
|
121 |
+
Xdata[bontiledict[bon_tile]].iloc[0] = 1
|
122 |
+
elif arg_no == 9: # fac_cols
|
123 |
+
for fac_col in user_input:
|
124 |
+
Xdata[fac_col.lower()].iloc[0] = 1
|
125 |
+
|
126 |
+
return Xdata, faction
|
127 |
+
|
128 |
+
def display_map(faction, map):
|
129 |
+
map_fig = plt.figure(tight_layout=True)
|
130 |
+
|
131 |
+
x, y = tmvis.display_map(faction, plot=False)
|
132 |
+
a = map_fig.add_subplot(111)
|
133 |
+
a.hexbin(x, y, gridsize=(19, 9), cmap='magma')
|
134 |
+
a.axis('off')
|
135 |
+
return map_fig
|
136 |
+
|
137 |
+
|
138 |
+
def predict(*args):
|
139 |
+
Xdata, faction = args_to_features(*args)
|
140 |
+
|
141 |
+
modelfile = f'D://PycharmProjects/TerraBot/data/faction-picker-bot/models/{faction}_model.txt'
|
142 |
+
bst = lgb.Booster(model_file=modelfile)
|
143 |
+
|
144 |
+
return f'Final score: {round(bst.predict(Xdata)[0])}'
|
145 |
+
|
146 |
+
|
147 |
+
def interpret(*args):
|
148 |
+
Xdata, faction = args_to_features(*args)
|
149 |
+
modelfile = f'D://PycharmProjects/TerraBot/data/faction-picker-bot/models/{faction}_model.txt'
|
150 |
+
bst = lgb.Booster(model_file=modelfile)
|
151 |
+
bst.params["objective"] = "regression"
|
152 |
+
explainer = shap.Explainer(bst)
|
153 |
+
|
154 |
+
copycols = []
|
155 |
+
for ii, column in enumerate(Xdata.columns):
|
156 |
+
if column[-6:] in round_tiles_dict_reverse.keys():
|
157 |
+
copycols.append(column[:3] + round_tiles_dict_reverse[column[-6:]])
|
158 |
+
elif column in bontiledict_reverse.keys():
|
159 |
+
copycols.append(bontiledict_reverse[column])
|
160 |
+
else:
|
161 |
+
copycols.append(column)
|
162 |
+
|
163 |
+
Xdata.columns = copycols
|
164 |
+
|
165 |
+
shap_values = explainer(Xdata)
|
166 |
+
fig_m = plt.figure(tight_layout=True, facecolor=(0.125,0.172,0.203))
|
167 |
+
ax = plt.gca()
|
168 |
+
ax.set_facecolor((0.125,0.172,0.203))
|
169 |
+
matplotlib.rcParams['axes.labelcolor'] = 'w'
|
170 |
+
shap.plots.waterfall(shap_values[0])
|
171 |
+
# shap.initjs()
|
172 |
+
# shap.plots.force(shap_values[0])
|
173 |
+
return fig_m
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
with gr.Blocks() as demo:
|
178 |
+
gr.Markdown("""
|
179 |
+
**Predict final faction score given the initial board setup 💰**: This model uses an lightgbm regression to make prediction.
|
180 |
+
The [source code for this work is here](https://github.com/guyreading/terrabot/faction-picker-bot/gradio_interface.py).
|
181 |
+
""")
|
182 |
+
with gr.Row():
|
183 |
+
with gr.Column():
|
184 |
+
faction = gr.Dropdown(
|
185 |
+
label="Faction",
|
186 |
+
choices=factions,
|
187 |
+
value=lambda: random.choice(factions),
|
188 |
+
)
|
189 |
+
|
190 |
+
round1_tile = gr.Dropdown(
|
191 |
+
label="Round 1 tile",
|
192 |
+
choices=round_tiles,
|
193 |
+
value=lambda: random.choice(round_tiles),
|
194 |
+
)
|
195 |
+
|
196 |
+
round2_tile = gr.Dropdown(
|
197 |
+
label="Round 2 tile",
|
198 |
+
choices=round_tiles,
|
199 |
+
value=lambda: random.choice(round_tiles),
|
200 |
+
)
|
201 |
+
|
202 |
+
round3_tile = gr.Dropdown(
|
203 |
+
label="Round 3 tile",
|
204 |
+
choices=round_tiles,
|
205 |
+
value=lambda: random.choice(round_tiles),
|
206 |
+
)
|
207 |
+
|
208 |
+
round4_tile = gr.Dropdown(
|
209 |
+
label="Round 4 tile",
|
210 |
+
choices=round_tiles,
|
211 |
+
value=lambda: random.choice(round_tiles),
|
212 |
+
)
|
213 |
+
|
214 |
+
round5_tile = gr.Dropdown(
|
215 |
+
label="Round 5 tile",
|
216 |
+
choices=round_tiles,
|
217 |
+
value=lambda: random.choice(round_tiles),
|
218 |
+
)
|
219 |
+
|
220 |
+
round6_tile = gr.Dropdown(
|
221 |
+
label="Round 6 tile",
|
222 |
+
choices=round_tiles,
|
223 |
+
value=lambda: random.choice(round_tiles),
|
224 |
+
)
|
225 |
+
|
226 |
+
bon_tiles_gr = gr.CheckboxGroup(label='Bonus tiles present', choices=list(bontiledict.keys()))
|
227 |
+
|
228 |
+
map = gr.Dropdown(
|
229 |
+
label="Map",
|
230 |
+
choices=maps,
|
231 |
+
value=lambda: random.choice(maps),
|
232 |
+
)
|
233 |
+
|
234 |
+
playerschosen = gr.Dropdown(
|
235 |
+
label="No. Of Players",
|
236 |
+
choices=players,
|
237 |
+
value=lambda: random.choice(players),
|
238 |
+
)
|
239 |
+
|
240 |
+
fac_cols_gr = gr.CheckboxGroup(label='Other faction colours present', choices=faction_cols)
|
241 |
+
|
242 |
+
|
243 |
+
with gr.Column():
|
244 |
+
map_plot = gr.Plot(label='Distance from home terrain: darker is further')
|
245 |
+
|
246 |
+
with gr.Row():
|
247 |
+
predict_btn = gr.Button(value="Predict")
|
248 |
+
interpret_btn = gr.Button(value="Explain")
|
249 |
+
|
250 |
+
label = gr.Label(label=f'Prediction of final VP for faction:')
|
251 |
+
plot = gr.Plot(label=f'Breakdown of prediction for faction:')
|
252 |
+
|
253 |
+
predict_btn.click(
|
254 |
+
predict,
|
255 |
+
inputs=[
|
256 |
+
round1_tile,
|
257 |
+
round2_tile,
|
258 |
+
round3_tile,
|
259 |
+
round4_tile,
|
260 |
+
round5_tile,
|
261 |
+
round6_tile,
|
262 |
+
faction,
|
263 |
+
map,
|
264 |
+
playerschosen,
|
265 |
+
bon_tiles_gr,
|
266 |
+
fac_cols_gr
|
267 |
+
],
|
268 |
+
outputs=[label],
|
269 |
+
)
|
270 |
+
interpret_btn.click(
|
271 |
+
interpret,
|
272 |
+
inputs=[
|
273 |
+
round1_tile,
|
274 |
+
round2_tile,
|
275 |
+
round3_tile,
|
276 |
+
round4_tile,
|
277 |
+
round5_tile,
|
278 |
+
round6_tile,
|
279 |
+
faction,
|
280 |
+
map,
|
281 |
+
playerschosen,
|
282 |
+
bon_tiles_gr,
|
283 |
+
fac_cols_gr
|
284 |
+
],
|
285 |
+
outputs=[plot],
|
286 |
+
)
|
287 |
+
|
288 |
+
faction.change(
|
289 |
+
display_map,
|
290 |
+
inputs=[
|
291 |
+
faction,
|
292 |
+
map,
|
293 |
+
],
|
294 |
+
outputs=[map_plot],
|
295 |
+
)
|
296 |
+
|
297 |
+
map.change(
|
298 |
+
display_map,
|
299 |
+
inputs=[
|
300 |
+
faction,
|
301 |
+
map,
|
302 |
+
],
|
303 |
+
outputs=[map_plot],
|
304 |
+
)
|
305 |
+
|
306 |
+
demo.launch()
|