mnist_counter / app.py
somethingbyai's picture
Update app.py
b1688f8
raw
history blame contribute delete
No virus
2.04 kB
import numpy as np
import torch
import model
import helper
import MyDataSet
import model_probit
import gradio as gr
# Build the latent-space classifier and the VAE, then hand each fresh module to
# the project helper that loads its weights (both with activate_eval=True —
# presumably switching the model to eval mode; TODO confirm in helper).
lattent_2class = helper.import_model_name(
    model_x=model_probit.latent_space(),
    activate_eval=True,
)
model_vae = helper.import_model_name_weights_copy(
    model_x=model.VaeFinal_only_one_hidden_copy(),
    activate_eval=True,
)
def ret_w_w0():
    """Return ``(w, w0)`` read from the first two parameters of ``lattent_2class``.

    ``w`` is index 0 of the first parameter's data and ``w0`` is index 0 of the
    second. These are presumably the weight row and the bias of a linear layer;
    TODO confirm against ``model_probit.latent_space``.
    """
    param_iter = iter(lattent_2class.parameters())
    # Tuple items are evaluated left to right: first parameter, then second.
    return next(param_iter).data[0], next(param_iter).data[0]
# Encode the 4-vs-9 training subset once at start-up. The VAE output bound to
# ``mu49`` becomes the module-level pool ``z`` that ``predict`` indexes into.
with torch.no_grad():
    # NOTE(review): batch_size_train=-1 presumably means "a single batch holding
    # every sample" — TODO confirm in MyDataSet.
    dataset_49 = MyDataSet.MyDataSets_Subset_4_9(batch_size_train=-1)
    # Only the first batch yielded by the loader is used.
    img_49_batch, label_49_batch = next(iter(dataset_49.train_loader_subset_changed_labels))
    # The VAE forward pass is unpacked into three outputs (named rec/mu/sigma
    # here); only the second one is kept, as z.
    rec49, mu49, sigma49 = model_vae(img_49_batch.clone())
    z = mu49
# Classifier parameters; these globals are bound as the default ``w``/``w0``
# arguments of ``return_counter_with_steps``.
w, w0 = ret_w_w0()
def return_counter_with_steps(z, steps=20, w=w, w0=w0):
    """Decode a counterfactual image by moving ``z`` along the classifier normal ``w``.

    For the linear decision function ``w . z + w0``, the scalar
    ``alpha = -(w . z + w0) / (w . w)`` is the signed step along ``w`` that
    projects ``z`` exactly onto the hyperplane ``w . z + w0 = 0``. The latent is
    moved by ``steps * alpha * w``: ``steps=1`` lands on the boundary, and larger
    values push further past it.

    Args:
        z: latent vector(s); assumed to have ``w``'s length in the last
            dimension (a single row of the module-level ``z`` when called from
            ``predict``) — TODO confirm shapes.
        steps: multiplier applied to the projection step ``alpha * w``.
        w: classifier weight vector (defaults to the module-level ``w``).
        w0: classifier bias (defaults to the module-level ``w0``).

    Returns:
        The decoded image, reshaped with ``view(28, 28)``.
    """
    with torch.no_grad():
        # Use real inner products. Element-wise ``*`` combined with the
        # left-to-right ``/ w * w`` made the division and multiplication by w
        # cancel out, so the projection step was never computed.
        margin = (z * w).sum(dim=-1, keepdim=True) + w0
        alpha = -margin / (w * w).sum()
        z_counter = z + steps * alpha * w
        z_counter_recons_img = model_vae.decode(z_counter)
    return z_counter_recons_img.view(28, 28)
def predict(index, steps):
    """Gradio callback: build the counterfactual image for latent ``z[index]``.

    Args:
        index: position in the module-level ``z``. The Gradio ``Number`` widget
            may deliver a float (truncated with ``int``) or ``None``.
        steps: multiplier forwarded to ``return_counter_with_steps`` (truncated
            with ``int``).

    Returns:
        A NumPy array of the decoded image with values clamped to ``[0, 1]``.

    Raises:
        ValueError: if an input is missing, or ``index`` is outside
            ``0 .. len(z) - 1``.
    """
    if index is None or steps is None:
        raise ValueError("Both 'index of z' and 'steps towards border' are required.")
    index = int(index)
    steps = int(steps)
    # Reject negatives explicitly: Python indexing would silently wrap them
    # around to an unrelated image at the end of the dataset.
    if not 0 <= index < len(z):
        raise ValueError(f"index of z must be between 0 and {len(z) - 1}, got {index}.")
    counter_image = return_counter_with_steps(z=z[index], steps=steps)
    counter_image = counter_image.clamp(0, 1)
    return return_image_in_format(counter_image)
def return_image_in_format(counter_image):
    """Convert an image (torch tensor or array-like) to a NumPy array for Gradio.

    Torch tensors are detached and moved to the CPU first. ``np.asarray`` raises
    on tensors that require grad or live on a GPU. Any other input is passed
    straight to ``np.asarray``.
    """
    if isinstance(counter_image, torch.Tensor):
        counter_image = counter_image.detach().cpu()
    return np.asarray(counter_image)
def show_counterfactual(index, steps):
    """Return a ``gr.Image`` component pre-filled with the counterfactual.

    The image is produced by ``predict(index, steps)``, which already returns a
    NumPy array. The original trailing comma wrapped the component in a 1-tuple;
    this now returns the component itself.

    NOTE(review): ``shape=`` is a Gradio 3.x ``gr.Image`` argument (removed in
    4.x). This function is not referenced by ``demo`` in this file.
    """
    return gr.Image(image_mode='L', shape=(28, 28), value=predict(index, steps))
# Gradio UI: two numeric inputs are passed to ``predict`` and the returned
# array is displayed as an image. The app is launched at import time.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="index of z"),
        gr.Number(label="steps towards border"),
    ],
    outputs="image",
    title="Counterfactuals",
    description="pick an image, show reconstruction, steps for counterfactuals!",
)
demo.launch()