Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import model | |
import helper | |
import MyDataSet | |
import model_probit | |
import gradio as gr | |
# Build the latent-space classifier and let the project helper finish setting
# it up. `activate_eval=True` is forwarded to the helper unchanged.
# presumably the helper loads stored weights and calls .eval() — confirm in helper.py
lattent_2class = helper.import_model_name(
    model_x=model_probit.latent_space(),
    activate_eval=True,
)

# Same pattern for the VAE used to encode and decode the images.
model_vae = helper.import_model_name_weights_copy(
    model_x=model.VaeFinal_only_one_hidden_copy(),
    activate_eval=True,
)
def ret_w_w0():
    """Return element 0 of the first two parameter tensors of ``lattent_2class``.

    The parameters are consumed in the order ``lattent_2class.parameters()``
    yields them; for each of the first two, ``.data[0]`` is taken.

    Returns:
        tuple: ``(w, w0)`` where ``w`` comes from the first parameter and
        ``w0`` from the second.
        presumably the weight row and bias of a single linear layer — confirm
        against ``model_probit.latent_space``.
    """
    params = iter(lattent_2class.parameters())
    # Tuple elements are evaluated left to right, so the first parameter
    # still becomes ``w`` and the second ``w0``.
    return next(params).data[0], next(params).data[0]
# Run the VAE once at import time on the first batch of the 4/9 subset loader,
# with autograd disabled because no training happens in this app.
with torch.no_grad():
    dataset_49 = MyDataSet.MyDataSets_Subset_4_9(batch_size_train=-1)
    loader_49 = dataset_49.train_loader_subset_changed_labels
    img_49_batch, label_49_batch = next(iter(loader_49))
    # The VAE is fed a copy of the batch, not the loader's tensor itself.
    rec49, mu49, sigma49 = model_vae(img_49_batch.clone())

# The second VAE output is what the app indexes into as the latent code.
z = mu49

# Classifier parameters used as defaults by return_counter_with_steps.
w, w0 = ret_w_w0()
def return_counter_with_steps(z, steps=20, w=w, w0=w0):
    """Shift a latent code along ``w`` and decode the result to a 28x28 image.

    Args:
        z: Latent tensor to shift (``predict`` passes a single row ``z[index]``).
        steps: Multiplier applied to the per-step shift.
        w: Direction tensor; defaults to the module-level ``w`` captured when
            this function is defined.
        w0: Offset term; defaults to the module-level ``w0`` captured when
            this function is defined.

    Returns:
        The output of ``model_vae.decode`` for the shifted code, viewed as
        ``(28, 28)``.
    """
    # NOTE(review): with tensor operands every ``*`` and ``/`` here is
    # elementwise and evaluated left to right, i.e. ((-(z*w + w0)) / w) * w.
    # If the intent was the hyperplane projection -(w.z + w0) / (w.w), this is
    # not it — confirm before changing, since the default ``steps`` may have
    # been tuned against the current behaviour.
    affine = torch.t(z) * w + w0
    alpha_i = -affine / torch.t(w) * w
    shift = steps * alpha_i * w
    z_counter = z + shift

    with torch.no_grad():
        decoded = model_vae.decode(z_counter)
    return decoded.view(28, 28)
def predict(index, steps):
    """Gradio callback: build the counterfactual image for one latent code.

    Args:
        index: Position in the module-level ``z`` tensor; converted with
            ``int()``.
        steps: Step multiplier forwarded to ``return_counter_with_steps``;
            converted with ``int()``.

    Returns:
        The decoded image, clamped to ``[0, 1]`` and converted by
        ``return_image_in_format``.
    """
    print(index, steps)
    # Convert both inputs before touching ``z`` so a bad value fails first.
    row, n_steps = int(index), int(steps)
    decoded = return_counter_with_steps(z=z[row], steps=n_steps)
    return return_image_in_format(decoded.clamp(0, 1))
def return_image_in_format(counter_image):
    """Convert ``counter_image`` to a NumPy array via ``np.asarray``.

    Args:
        counter_image: Any array-like accepted by ``np.asarray``
            (``predict`` passes a clamped tensor).

    Returns:
        numpy.ndarray: The array form of the input.
    """
    return np.asarray(counter_image)
def show_counterfactual(index, steps):
    """Wrap the counterfactual for ``index``/``steps`` in a ``gr.Image``.

    Args:
        index: Forwarded to ``predict``.
        steps: Forwarded to ``predict``.

    Returns:
        tuple: A one-element tuple holding a grayscale (``image_mode='L'``)
        ``gr.Image`` with ``shape=(28, 28)`` whose value is the predicted
        image.
    """
    pixels = np.asarray(predict(index, steps))
    image = gr.Image(image_mode='L', shape=(28, 28), value=pixels)
    # NOTE(review): the original assignment ended with a trailing comma, so
    # the result is a 1-tuple and not the component itself — confirm that
    # this is intended before unwrapping it.
    return (image,)
# Two numeric inputs, passed positionally to `predict` as (index, steps).
demo_inputs = [
    gr.Number(label="index of z"),
    gr.Number(label="steps towards border"),
]

# Wire `predict` into a simple Interface that renders its result as an image.
demo = gr.Interface(
    fn=predict,
    inputs=demo_inputs,
    outputs='image',
    title="Counterfactuals",
    description="pick an image, show reconstruction, steps for counterfactuals!",
)

# Started unconditionally when the module is executed.
demo.launch()