"""Gradio demo for image counterfactual inference on three datasets.

Each tab (Morpho-MNIST, Brain MRI, Chest X-ray) lets the user pick a test
observation, choose interventions on its parent attributes, and view the
counterfactual image, its per-pixel uncertainty across particles, and the
difference between counterfactual and reconstruction ("direct causal effect").
"""
import torch
import numpy as np
import gradio as gr
import matplotlib.pylab as plt
import torch.nn.functional as F

from vae import HVAE
from datasets import morphomnist, ukbb, mimic, get_attr_max_min
from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM
from app_utils import (
    mnist_graph,
    brain_graph,
    chest_graph,
    vae_preprocess,
    normalize,
    preprocess_brain,
    get_fig_arr,
    postprocess,
    MidpointNormalize,
)

# Lazily-filled per-dataset caches:
#   DATA[id]["test"]  -> test DataLoader (only its .dataset is used here)
#   MODELS[id]        -> "pgm", "pgm_args", "vae", "vae_args"
DATA, MODELS = {}, {}
for k in ["Morpho-MNIST", "Brain MRI", "Chest X-ray"]:
    DATA[k], MODELS[k] = {}, {}

# mnist
DIGITS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
# Ranges used to map Morpho-MNIST attributes between [-1, 1] and original units.
THICKNESS_MIN, THICKNESS_MAX = 0.87598526, 6.255515
INTENSITY_MIN, INTENSITY_MAX = 66.601204, 254.90317
# brain
MRISEQ_CAT = ["T1", "T2-FLAIR"]  # 0,1
SEX_CAT = ["female", "male"]  # 0,1
HEIGHT, WIDTH = 270, 270
# chest
SEX_CAT_CHEST = ["male", "female"]  # 0,1
RACE_CAT = ["white", "asian", "black"]  # 0,1,2
FIND_CAT = ["no disease", "pleural effusion"]
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Hparams:
    """Attribute bag populated from a checkpoint's saved hyperparameter dict."""

    def update(self, dict):
        """Set every key/value pair of ``dict`` as an attribute of ``self``."""
        for k, v in dict.items():
            setattr(self, k, v)


def get_paths(dataset_id):
    """Return ``(data_path, vae_path, pgm_path)`` for ``dataset_id``.

    For the Chest X-ray dataset ``vae_path`` is a two-element list:
    ``[base VAE checkpoint, counterfactually-trained DSCM checkpoint]``.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt",  # base vae
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt",  # cf trained DSCM
        ]
    else:
        raise ValueError(f"Unknown dataset id: {dataset_id!r}")
    return data_path, vae_path, pgm_path


def _load_checkpoint(path):
    """Load a checkpoint onto DEVICE and wrap its "hparams" in an Hparams.

    Returns:
        ``(checkpoint, args)`` where ``args.device`` is set to DEVICE.
    """
    checkpoint = torch.load(path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE
    return checkpoint, args


def load_pgm(dataset_id, pgm_path):
    """Build the dataset's PGM, load its EMA weights and cache it in MODELS.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    checkpoint, args = _load_checkpoint(pgm_path)
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset id: {dataset_id!r}")
    pgm = pgm_cls(args).to(args.device)
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args


def load_vae(dataset_id, vae_path):
    """Build the dataset's HVAE, load its EMA weights and cache it in MODELS.

    For Chest X-ray, hparams come from the base VAE checkpoint but the weights
    are the ``vae.``-prefixed entries of the DSCM checkpoint.
    """
    dscm_path = None
    if "Chest" in dataset_id:
        vae_path, dscm_path = vae_path[0], vae_path[1]
    checkpoint, args = _load_checkpoint(vae_path)
    # backwards compatibility hack: default hparams older checkpoints may lack
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        args.kl_free_bits = args.free_bits
    vae = HVAE(args).to(args.device)

    if dscm_path is not None:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        prefix = "vae."
        # Keep only the VAE sub-module's weights, with the prefix stripped.
        vae.load_state_dict(
            {
                k[len(prefix):]: v
                for k, v in dscm_ckpt["ema_model_state_dict"].items()
                if k.startswith(prefix)
            }
        )
    else:
        vae.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args


def get_dataloader(dataset_id, data_path):
    """Create the test DataLoader for ``dataset_id`` using its PGM hparams."""
    MODELS[dataset_id]["pgm_args"].data_dir = data_path
    args = MODELS[dataset_id]["pgm_args"]
    if "MNIST" in dataset_id:
        datasets = morphomnist(args)
    elif "Brain" in dataset_id:
        datasets = ukbb(args)
    elif "Chest" in dataset_id:
        datasets = mimic(args)
    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
    )


def load_dataset(dataset_id):
    """Load the PGM hparams (needed by the dataset builders) and the test set."""
    data_path, _, pgm_path = get_paths(dataset_id)
    _, args = _load_checkpoint(pgm_path)
    MODELS[dataset_id]["pgm_args"] = args
    get_dataloader(dataset_id, data_path)


def load_model(dataset_id):
    """Load the PGM and VAE for ``dataset_id`` unless they are already cached.

    This runs on every tab selection, so skipping already-loaded models avoids
    re-reading the checkpoints from disk each time the user switches tabs.
    """
    models = MODELS[dataset_id]
    if "pgm" in models and "vae" in models:
        return
    _, vae_path, pgm_path = get_paths(dataset_id)
    load_pgm(dataset_id, pgm_path)
    load_vae(dataset_id, vae_path)


@torch.no_grad()
def counterfactual_inference(dataset_id, obs, do_pa):
    """Infer counterfactual parents and image for ``obs`` under ``do_pa``.

    1. The PGM maps the observed parents (every key of ``obs`` except "x")
       to counterfactual parents under the intervention ``do_pa``.
    2. The VAE abducts latents ``z`` from the observed image and parents.
    3. ``z`` is decoded under both factual and counterfactual parents; the
       factual standardized residual ``u`` is re-applied to the counterfactual
       decoder output using the counterfactual scale.

    Args:
        dataset_id: key into MODELS.
        obs: dict of batched tensors; "x" is the image, other keys are parents.
        do_pa: dict mapping intervened parent names to batched tensors.

    Returns:
        dict with "cf_x" (clamped to [-1, 1]), "rec_x" (factual decoder
        location) and "cf_pa" (counterfactual parents from the PGM).
    """
    pa = {k: v.clone() for k, v in obs.items() if k != "x"}
    cf_pa = MODELS[dataset_id]["pgm"].counterfactual(
        obs=pa, intervention=do_pa, num_particles=1
    )
    args, vae = MODELS[dataset_id]["vae_args"], MODELS[dataset_id]["vae"]
    _pa = vae_preprocess(args, {k: v.clone() for k, v in pa.items()})
    _cf_pa = vae_preprocess(args, {k: v.clone() for k, v in cf_pa.items()})
    # Morpho-MNIST models use a lower temperature for abduction and sampling.
    z_t = 0.1 if "mnist" in args.hps else 1.0
    z = vae.abduct(x=obs["x"], parents=_pa, t=z_t)
    if vae.cond_prior:
        z = [z[j]["z"] for j in range(len(z))]
    px_loc, px_scale = vae.forward_latents(latents=z, parents=_pa)
    cf_loc, cf_scale = vae.forward_latents(latents=z, parents=_cf_pa)
    # Standardized residual of the observation; clamp avoids division by zero.
    u = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
    u_t = 0.1 if "mnist" in args.hps else 1.0  # cf sampling temp
    cf_scale = cf_scale * u_t
    cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_pa}


def get_obs_item(dataset_id, idx=None):
    """Return ``(idx, item)`` from the test set; a random index if ``idx`` is None."""
    dataset = DATA[dataset_id]["test"].dataset
    if idx is None:
        idx = torch.randint(len(dataset), (1,))
    idx = int(idx)
    return idx, dataset[idx]


def _unnormalize(x, x_min, x_max):
    """Map ``x`` from [-1, 1] back to the range [x_min, x_max]."""
    return (x + 1) / 2 * (x_max - x_min) + x_min


def _majority(t):
    """Decode a binary attribute by rounding its mean over the particle axis.

    ``.item()`` requires the mean to hold a single element.
    """
    return round(t.mean(0).item())


def _expand_interventions(do_pa, device, n_particles):
    """Move each intervention tensor to ``device`` as float, repeated per particle."""
    return {
        k: v.to(device).float().repeat(n_particles, 1) for k, v in do_pa.items()
    }


def _cf_image_panels(out, effect_range=None):
    """Build the counterfactual, uncertainty and effect figure arrays.

    Args:
        out: result of ``counterfactual_inference``.
        effect_range: optional ``(vmin, vmax)`` for the effect colormap; when
            None, the min/max of the effect image itself are used.

    Returns:
        ``(cf_x, cf_x_std, effect)`` figure arrays from ``get_fig_arr``.
    """
    # Average over the particle axis.
    cf_x = postprocess(out["cf_x"].mean(0))
    cf_x_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_x = postprocess(out["rec_x"].mean(0))
    effect = cf_x - rec_x
    if effect_range is None:
        vmin, vmax = effect.min(), effect.max()
    else:
        vmin, vmax = effect_range
    effect = get_fig_arr(
        effect, cmap="RdBu_r", norm=MidpointNormalize(vmin=vmin, midpoint=0, vmax=vmax)
    )
    return get_fig_arr(cf_x), get_fig_arr(cf_x_std, cmap="jet"), effect


def get_mnist_obs(idx=None):
    """Return ``(idx, image, thickness, intensity, digit)`` for a Morpho-MNIST item."""
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
    t = _unnormalize(obs["thickness"].clone(), THICKNESS_MIN, THICKNESS_MAX)
    i = _unnormalize(obs["intensity"].clone(), INTENSITY_MIN, INTENSITY_MAX)
    y = DIGITS[obs["digit"].clone().argmax(-1)]
    return (idx, x, float(np.round(t, 2)), float(np.round(i, 2)), y)


def get_brain_obs(idx=None):
    """Return ``(idx, image, mri_seq, sex, age, brain_vol_ml, ventricle_vol_ml)``."""
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    x = get_fig_arr(obs["x"].clone().squeeze().numpy())
    m = MRISEQ_CAT[int(obs["mri_seq"].clone().item())]
    s = SEX_CAT[int(obs["sex"].clone().item())]
    a = obs["age"].clone().item()
    b = obs["brain_volume"].clone().item() / 1000  # in ml
    v = obs["ventricle_volume"].clone().item() / 1000  # in ml
    return (idx, x, m, s, a, float(np.round(b, 2)), float(np.round(v, 2)))


def get_chest_obs(idx=None):
    """Return ``(idx, image, race, sex, finding, age)`` for a chest X-ray item."""
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)
    x = get_fig_arr(postprocess(obs["x"].clone()))
    s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
    f = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
    r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
    # age is stored in [-1, 1]; the inverse of a / 100 * 2 - 1
    a = (obs["age"].clone().squeeze().numpy() + 1) * 50
    return (idx, x, r, s, f, float(np.round(a, 1)))


def infer_mnist_cf(*args):
    """Gradio handler: Morpho-MNIST counterfactual for the current observation.

    ``args`` is ``obs + do`` as wired in the UI:
    ``(idx, image, t, i, y, do_t, do_i, do_y)``.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32
    device = MODELS[dataset_id]["vae_args"].device
    # preprocess: scale pixels to [-1, 1], add batch dim, repeat per particle
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs["x"] = (obs["x"] - 127.5) / 127.5
    for k, v in obs.items():
        obs[k] = v.view(1, 1) if len(v.shape) < 1 else v.unsqueeze(0)
        obs[k] = obs[k].to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)
    # intervention(s)
    do_pa = {}
    if do_t:
        do_pa["thickness"] = torch.tensor(
            normalize(t, x_max=THICKNESS_MAX, x_min=THICKNESS_MIN)
        ).view(1, 1)
    if do_i:
        do_pa["intensity"] = torch.tensor(
            normalize(i, x_max=INTENSITY_MAX, x_min=INTENSITY_MIN)
        ).view(1, 1)
    if do_y:
        do_pa["digit"] = F.one_hot(torch.tensor(DIGITS.index(y)), num_classes=10).view(
            1, 10
        )
    do_pa = _expand_interventions(do_pa, device, n_particles)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    # average cf parents over particles and map back to original units
    cf_pa = out["cf_pa"]
    cf_t = np.round(
        _unnormalize(cf_pa["thickness"].mean(0).item(), THICKNESS_MIN, THICKNESS_MAX), 2
    )
    cf_i = np.round(
        _unnormalize(cf_pa["intensity"].mean(0).item(), INTENSITY_MIN, INTENSITY_MAX), 2
    )
    cf_y = DIGITS[cf_pa["digit"].mean(0).argmax(-1)]
    # MNIST effect uses a fixed pixel-scale colour range
    cf_x, cf_x_std, effect = _cf_image_panels(out, effect_range=(-255, 255))
    return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)


def infer_brain_cf(*args):
    """Gradio handler: Brain MRI counterfactual for the current observation.

    ``args`` is ``obs_brain + do_brain`` as wired in the UI:
    ``(idx, image, m, s, a, b, v, do_m, do_s, do_a, do_b, do_v)``.
    """
    dataset_id = "Brain MRI"
    idx, _, m, s, a, b, v = args[:7]
    do_m, do_s, do_a, do_b, do_v = args[7:]
    n_particles = 16
    device = MODELS[dataset_id]["vae_args"].device
    # preprocessing
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
    for k, _v in obs.items():
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = _v.repeat(n_particles, *ndims)
    # intervention(s); volumes are entered in ml and converted back here
    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)
    # normalize continuous attributes to [-1, 1]
    for k in ["age", "brain_volume", "ventricle_volume"]:
        if k in do_pa:
            k_max, k_min = get_attr_max_min(k)
            do_pa[k] = (do_pa[k] - k_min) / (k_max - k_min)  # [0,1]
            do_pa[k] = 2 * do_pa[k] - 1  # [-1,1]
    do_pa = _expand_interventions(do_pa, device, n_particles)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    # decode cf parents: majority vote for binary, unnormalized mean otherwise
    cf_m = MRISEQ_CAT[_majority(out["cf_pa"]["mri_seq"])]
    cf_s = SEX_CAT[_majority(out["cf_pa"]["sex"])]
    cf_ = {}
    for k in ["age", "brain_volume", "ventricle_volume"]:
        k_max, k_min = get_attr_max_min(k)
        cf_[k] = _unnormalize(out["cf_pa"][k].mean(0).item(), k_min, k_max)
    cf_x, cf_x_std, effect = _cf_image_panels(out)
    return (
        cf_x,
        cf_x_std,
        effect,
        cf_m,
        cf_s,
        np.round(cf_["age"], 1),
        np.round(cf_["brain_volume"] / 1000, 2),
        np.round(cf_["ventricle_volume"] / 1000, 2),
    )


def infer_chest_cf(*args):
    """Gradio handler: chest X-ray counterfactual for the current observation.

    ``args`` is ``obs_chest + do_chest`` as wired in the UI:
    ``(idx, image, r, s, f, a, do_r, do_s, do_f, do_a)``.
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16
    device = MODELS[dataset_id]["vae_args"].device
    # preprocessing
    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    for k, v in obs.items():
        obs[k] = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)
    # intervention(s)
    do_pa = {}
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
    if do_f:
        do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
    if do_r:
        do_pa["race"] = F.one_hot(
            torch.tensor(RACE_CAT.index(r)), num_classes=3
        ).view(1, 3)
    if do_a:
        do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)  # to [-1, 1]
    do_pa = _expand_interventions(do_pa, device, n_particles)
    # infer counterfactual
    out = counterfactual_inference(dataset_id, obs, do_pa)
    # decode cf parents
    cf_pa = out["cf_pa"]
    cf_r = RACE_CAT[cf_pa["race"].mean(0).argmax(-1)]
    cf_s = SEX_CAT_CHEST[_majority(cf_pa["sex"])]
    cf_f = FIND_CAT[_majority(cf_pa["finding"])]
    cf_a = (cf_pa["age"].mean(0).item() + 1) * 50
    cf_x, cf_x_std, effect = _cf_image_panels(out)
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))


# Header shown above the intervention controls of every tab.
_HEADER_MD = (
    "**Intervention**"
    + 20 * " "
    + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
    + "  |   Hint: try 90% zoom"
)

with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Morpho-MNIST") as mnist_tab:
            mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                idx = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x = gr.Image(label="Counterfactual", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_std = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)

            with gr.Row().style(equal_height=True):
                with gr.Column(scale=1.75):
                    gr.Markdown(_HEADER_MD)
                    with gr.Column():
                        do_y = gr.Checkbox(label="do(digit)", value=False)
                        y = gr.Radio(DIGITS, label="", interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_t = gr.Checkbox(label="do(thickness)", value=False)
                            t = gr.Slider(
                                label="\u00A0",
                                minimum=0.9,
                                maximum=5.5,
                                step=0.01,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_i = gr.Checkbox(label="do(intensity)", value=False)
                            i = gr.Slider(
                                label="\u00A0",
                                minimum=50,
                                maximum=255,
                                step=0.01,
                                interactive=False,
                            )
                    with gr.Row():
                        new = gr.Button("New Observation")
                        reset = gr.Button("Reset", variant="stop")
                        submit = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    gr.Markdown("###  ")
                    causal_graph = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=300)

        with gr.TabItem("Brain MRI") as brain_tab:
            brain_id = gr.Textbox(value=brain_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                idx_brain = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_brain = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_brain = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_brain = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_brain = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)

            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(_HEADER_MD)
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
                            m = gr.Radio(MRISEQ_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s = gr.Checkbox(label="do(sex)", value=False)
                            s = gr.Radio(SEX_CAT, label="", interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            do_a = gr.Checkbox(label="do(age)", value=False)
                            a = gr.Slider(
                                label="\u00A0",
                                value=50,
                                minimum=44,
                                maximum=73,
                                step=1,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_b = gr.Checkbox(label="do(brain volume)", value=False)
                            b = gr.Slider(
                                label="\u00A0",
                                value=1000,
                                minimum=850,
                                maximum=1550,
                                step=20,
                                interactive=False,
                            )
                        with gr.Column(min_width=100):
                            do_v = gr.Checkbox(
                                label="do(ventricle volume)", value=False
                            )
                            v = gr.Slider(
                                label="\u00A0",
                                value=40,
                                minimum=10,
                                maximum=125,
                                step=2,
                                interactive=False,
                            )
                    with gr.Row():
                        new_brain = gr.Button("New Observation")
                        reset_brain = gr.Button("Reset", variant="stop")
                        submit_brain = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_brain = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=340)

        with gr.TabItem("Chest X-ray") as chest_tab:
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)

            with gr.Row().style(equal_height=True):
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(label="Observation", interactive=False).style(
                        height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False
                    ).style(height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect", interactive=False
                    ).style(height=HEIGHT)

            with gr.Row():
                with gr.Column(scale=2.55):
                    gr.Markdown(_HEADER_MD)
                    with gr.Row().style(equal_height=True):
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )
                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            a_chest = gr.Slider(
                                label="\u00A0", minimum=18, maximum=98, step=1
                            )
                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False
                    ).style(height=345)

    # morphomnist
    do = [do_t, do_i, do_y]
    obs = [idx, x, t, i, y]
    cf_out = [cf_x, cf_x_std, effect]
    # brain
    do_brain = [do_m, do_s, do_a, do_b, do_v]  # intervention checkboxes
    obs_brain = [idx_brain, x_brain, m, s, a, b, v]  # observed image/attributes
    cf_out_brain = [cf_x_brain, cf_x_std_brain, effect_brain]  # counterfactual outputs
    # chest
    do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
    obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
    cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]

    # on start: load new observations & causal graph
    demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
    demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
    demo.load(fn=load_model, inputs=mnist_id, outputs=None)
    demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
    demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # on tab select: load models (no-op once cached)
    brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
    chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)

    # "new" button: load new observations
    new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
    new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)

    # "new" button: reset causal graphs
    new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
    new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
    new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "new" button: reset cf output panels
    for _k, _v in zip(
        [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # "reset" button: reload current observations
    reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
    reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
    reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)

    # "reset" button: deselect intervention checkboxes
    reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
    reset_brain.click(
        fn=lambda: (gr.update(value=False),) * len(do_brain),
        inputs=None,
        outputs=do_brain,
    )
    reset_chest.click(
        fn=lambda: (gr.update(value=False),) * len(do_chest),
        inputs=None,
        outputs=do_chest,
    )

    # "reset" button: close stale matplotlib figures & reset cf output panels
    for _k, _v in zip(
        [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
    ):
        _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
        _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)

    # enable mnist interventions when checkbox is selected & update graph
    for _k, _v in zip(do, [t, i, y]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(mnist_graph, inputs=do, outputs=causal_graph)

    # enable brain interventions when checkbox is selected & update graph
    for _k, _v in zip(do_brain, [m, s, a, b, v]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)

    # enable chest interventions when checkbox is selected & update graph
    for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "submit" button: infer counterfactuals
    submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
    submit_brain.click(
        fn=infer_brain_cf,
        inputs=obs_brain + do_brain,
        outputs=cf_out_brain + [m, s, a, b, v],
    )
    submit_chest.click(
        fn=infer_chest_cf,
        inputs=obs_chest + do_chest,
        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()