guyar's picture
Fix: add share=True in demo.launch in app.py
598ec00
raw
history blame
9.9 kB
import gradio as gr
import random
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import shap
import lightgbm as lgb
import yaml
import numpy as np
import os
from visualisations import terravisualisation as tmvis
matplotlib.use("Agg")
bontiledict = {'SPD + 2C':'BON1',
'cult + 4C':'BON2',
'+6C':'BON3',
'+3pw 1 ship':'BON4',
'+1W + 3PW':'BON5',
'pass-vp:SA/SH*4 + 2W':'BON6',
'pass-vp:TP*2 + 1W':'BON7',
'+1P':'BON8',
'pass-vp:D*1 + 2C':'BON9',
'pass-vp: ship*3 + 3pw':'BON10'
}
bontiledict_reverse = {'BON1': 'SPD + 2C',
'BON2': 'cult + 4C',
'BON3': '+6C',
'BON4': '+3pw 1 ship',
'BON5': '+1W + 3PW',
'BON6': 'pass-vp:SA/SH*4 + 2W',
'BON7': 'pass-vp:TP*2 + 1W',
'BON8': '+1P',
'BON9': 'pass-vp:D*1 + 2C',
'BON10': 'pass-vp: ship*3 + 3pw'
}
round_tiles_dict = {'SPADE >> 2':'SCORE1',
'TOWN >> 5':'SCORE2',
'D >> 2':'SCORE3',
'SA/SH >> 5':'SCORE4',
'D >> 2':'SCORE5',
'TP >> 3':'SCORE6',
'SA/SH >> 5':'SCORE7',
'TP >> 3':'SCORE8',
'TE >> 4':'SCORE9'}
round_tiles_dict_reverse = {'SCORE1': 'SPADE >> 2',
'SCORE2': 'TOWN >> 5',
'SCORE3': 'D >> 2',
'SCORE4': 'SA/SH >> 5',
'SCORE5': 'D >> 2',
'SCORE6': 'TP >> 3',
'SCORE7': 'SA/SH >> 5',
'SCORE8': 'TP >> 3',
'SCORE9': 'TE >> 4'}
round_tiles = list(round_tiles_dict.keys())
round6_tiles = round_tiles.copy()
round6_tiles.remove('SPADE >> 2')
bontiles = list(bontiledict.keys())
factions = ['Witches', 'Auren', 'Giants', 'Chaos Magicians', 'Darklings', 'Alchemists',
'Swarmlings', 'Mermaids', 'Fakirs', 'Nomads', 'Engineers', 'Dwarves', 'Halflings', 'Cultists']
players = ['2players', '3players', '4players', '5players']
maps = ['map1', 'map2', 'map3']
faction_cols = ['Yellow', 'Red', 'Grey', 'Black', 'Blue', 'Green', 'Brown']
with open('params.yaml', 'r') as fd:
params = yaml.safe_load(fd)
vpdfdir = params['prepare']['vp-data-dir']
featdfdir = params['prepare']['feature-data-dir']
pickledir = params['prepare-step2']['pickle-dir']
feature_columns = ['x0_SCORE1', 'x0_SCORE2', 'x0_SCORE3', 'x0_SCORE4', 'x0_SCORE5',
'x0_SCORE6', 'x0_SCORE7', 'x0_SCORE8', 'x0_SCORE9', 'x1_SCORE1',
'x1_SCORE2', 'x1_SCORE3', 'x1_SCORE4', 'x1_SCORE5', 'x1_SCORE6',
'x1_SCORE7', 'x1_SCORE8', 'x1_SCORE9', 'x2_SCORE1', 'x2_SCORE2',
'x2_SCORE3', 'x2_SCORE4', 'x2_SCORE5', 'x2_SCORE6', 'x2_SCORE7',
'x2_SCORE8', 'x2_SCORE9', 'x3_SCORE1', 'x3_SCORE2', 'x3_SCORE3',
'x3_SCORE4', 'x3_SCORE5', 'x3_SCORE6', 'x3_SCORE7', 'x3_SCORE8',
'x3_SCORE9', 'x4_SCORE1', 'x4_SCORE2', 'x4_SCORE3', 'x4_SCORE4',
'x4_SCORE5', 'x4_SCORE6', 'x4_SCORE7', 'x4_SCORE8', 'x4_SCORE9',
'x5_SCORE2', 'x5_SCORE3', 'x5_SCORE4', 'x5_SCORE5', 'x5_SCORE6',
'x5_SCORE7', 'x5_SCORE8', 'x5_SCORE9', 'BON1', 'BON2', 'BON3', 'BON4',
'BON5', 'BON6', 'BON7', 'BON8', 'BON9', 'BON10', 'no_players', 'red',
'blue', 'green', 'black', 'grey', 'yellow', 'brown', 'x0_map1',
'x0_map2', 'x0_map3']
def args_to_features(*args):
# round1, round2, round3, round4, round5, round6, faction, map, playerschosen, bon_tiles, fac_cols = args
Xdata = pd.DataFrame(data=np.zeros((1, len(feature_columns))), columns=feature_columns)
for arg_no, user_input in enumerate(args):
if arg_no in range(6): # if it's a round
# map back to col name
feat_label_name = f'x{arg_no}_{round_tiles_dict[user_input]}'
Xdata[feat_label_name].iloc[0] = 1
elif arg_no == 6:
faction = user_input
if faction == 'Chaos Magicians':
faction = 'chaosmagicians'
elif arg_no == 7: # map
feat_label_name = f'x0_{user_input}'
Xdata[feat_label_name].iloc[0] = 1
elif arg_no == 8: # playerschosen
Xdata['no_players'].iloc[0] = int(user_input[0])
elif arg_no == 9: # bon_tiles
for bon_tile in user_input:
Xdata[bontiledict[bon_tile]].iloc[0] = 1
elif arg_no == 9: # fac_cols
for fac_col in user_input:
Xdata[fac_col.lower()].iloc[0] = 1
return Xdata, faction
def display_map(faction, map):
map_fig = plt.figure(tight_layout=True)
x, y = tmvis.display_map(faction, plot=False)
a = map_fig.add_subplot(111)
a.hexbin(x, y, gridsize=(19, 9), cmap='magma')
a.axis('off')
return map_fig
def predict(*args):
Xdata, faction = args_to_features(*args)
modelfile = f'{os.getcwd()}/data/faction-picker-bot/models/{faction.lower()}_model.txt'
bst = lgb.Booster(model_file=modelfile)
return f'Final score: {round(bst.predict(Xdata)[0])}'
def interpret(*args):
Xdata, faction = args_to_features(*args)
modelfile = f'{os.getcwd()}/data/faction-picker-bot/models/{faction.lower()}_model.txt'
bst = lgb.Booster(model_file=modelfile)
bst.params["objective"] = "regression"
explainer = shap.Explainer(bst)
copycols = []
for ii, column in enumerate(Xdata.columns):
if column[-6:] in round_tiles_dict_reverse.keys():
copycols.append(column[:3] + round_tiles_dict_reverse[column[-6:]])
elif column in bontiledict_reverse.keys():
copycols.append(bontiledict_reverse[column])
else:
copycols.append(column)
Xdata.columns = copycols
shap_values = explainer(Xdata)
fig_m = plt.figure(tight_layout=True, facecolor=(0.125,0.172,0.203))
ax = plt.gca()
ax.set_facecolor((0.125,0.172,0.203))
matplotlib.rcParams['axes.labelcolor'] = 'w'
shap.plots.waterfall(shap_values[0])
# shap.initjs()
# shap.plots.force(shap_values[0])
return fig_m
with gr.Blocks() as demo:
gr.Markdown("""
**Predict final faction score given the initial board setup 💰**: This model uses an lightgbm regression to make prediction.
The [source code for this work is here](https://github.com/guyreading/terrabot/blob/main/app.py).
""")
with gr.Row():
with gr.Column():
faction = gr.Dropdown(
label="Faction",
choices=factions,
value=lambda: random.choice(factions),
)
round1_tile = gr.Dropdown(
label="Round 1 tile",
choices=round_tiles,
value=lambda: random.choice(round_tiles),
)
round2_tile = gr.Dropdown(
label="Round 2 tile",
choices=round_tiles,
value=lambda: random.choice(round_tiles),
)
round3_tile = gr.Dropdown(
label="Round 3 tile",
choices=round_tiles,
value=lambda: random.choice(round_tiles),
)
round4_tile = gr.Dropdown(
label="Round 4 tile",
choices=round_tiles,
value=lambda: random.choice(round_tiles),
)
round5_tile = gr.Dropdown(
label="Round 5 tile",
choices=round_tiles,
value=lambda: random.choice(round_tiles),
)
round6_tile = gr.Dropdown(
label="Round 6 tile",
choices=round6_tiles,
value=lambda: random.choice(round6_tiles),
)
bon_tiles_gr = gr.CheckboxGroup(label='Bonus tiles present', choices=list(bontiledict.keys()))
map = gr.Dropdown(
label="Map",
choices=maps,
value=lambda: random.choice(maps),
)
playerschosen = gr.Dropdown(
label="No. Of Players",
choices=players,
value=lambda: random.choice(players),
)
fac_cols_gr = gr.CheckboxGroup(label='Other faction colours present', choices=faction_cols)
with gr.Column():
map_plot = gr.Plot(label='Distance from home terrain: darker is further')
with gr.Row():
predict_btn = gr.Button(value="Predict")
interpret_btn = gr.Button(value="Explain")
label = gr.Label(label=f'Prediction of final VP for faction:')
plot = gr.Plot(label=f'Breakdown of prediction for faction:')
predict_btn.click(
predict,
inputs=[
round1_tile,
round2_tile,
round3_tile,
round4_tile,
round5_tile,
round6_tile,
faction,
map,
playerschosen,
bon_tiles_gr,
fac_cols_gr
],
outputs=[label],
)
interpret_btn.click(
interpret,
inputs=[
round1_tile,
round2_tile,
round3_tile,
round4_tile,
round5_tile,
round6_tile,
faction,
map,
playerschosen,
bon_tiles_gr,
fac_cols_gr
],
outputs=[plot],
)
faction.change(
display_map,
inputs=[
faction,
map,
],
outputs=[map_plot],
)
map.change(
display_map,
inputs=[
faction,
map,
],
outputs=[map_plot],
)
demo.launch(share=True)