# projection as state id
import streamlit as st
import pickle
import pandas as pd
import numpy as np
import random
import torch
from matplotlib.backends.backend_agg import RendererAgg
from backend.disentangle_concepts import *
import torch_utils
import dnnlib
import legacy
st.title('Disentanglement studies')
st.write('> **What concepts can be disentangled in the latent space of a model?**')

# Detailed usage notes; rendered inside the collapsible expander below.
instruction_text = """Instruction to input:
1. Choosing concept: Users can choose a **Concept** to disentangle and the latent **Space** (Z or W) to disentangle it in, then hit the _Choose the defined concept and space to disentangle_ button.
2. Choosing image: Users can choose a specific image by entering **Image ID** and hit the _Choose the defined image_ button or can generate an image randomly by hitting the _Generate a random image_ button.
3. Choosing epsilon: **Epsilon** is the lambda amount of translation along the disentangled concept axis. A negative epsilon changes the image in the direction of the concept, a positive one pushes the image away from the concept.
"""

st.write("To use the functionality below, users need to input the **concept** to disentangle, an **image** id and the **epsilon** of variation along the disentangled axis.")

with st.expander("See more instruction", expanded=False):
    st.markdown(instruction_text)
# Pre-computed annotations for the generated seeds (indexed below by
# 'z_vectors' / 'w_vectors').
# NOTE(review): pickle.load can execute arbitrary code; acceptable only because
# this is a trusted, repository-local artifact — never point it at user input.
annotations_file = './data/annotated_files/seeds0000-50000.pkl'
with open(annotations_file, 'rb') as f:
    annotations = pickle.load(f)

ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')

# One concept label per line. Blank lines (e.g. a trailing empty line) are
# skipped so they never appear as an empty option in the concept selectbox.
concepts = './data/concepts.txt'
with open(concepts, encoding='utf-8') as f:
    _stripped_lines = (line.strip() for line in f)
    labels = [label for label in _stripped_lines if label]

# Seed session state only when a key is absent, so values chosen by the user
# persist across script reruns.
_SESSION_DEFAULTS = {
    'image_id': 0,
    'projection': False,
    'concept_id': 'Abstract',
    'space_id': 'Z',
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ----------------------------- INPUT ----------------------------------
# Three side-by-side input panels: concept/space, image, and lambda/layers.
input_col_1, input_col_2, input_col_3 = st.columns(3)

# --------------------------- INPUT column 1 ---------------------------
with input_col_1:
    with st.form('text_form'):
        st.write('**Choose a concept to disentangle**')
        concept_id = st.selectbox('Concept:', tuple(labels))

        st.write('**Choose a latent space to disentangle**')
        space_id = st.selectbox('Space:', ('Z', 'W'))

        choose_text_button = st.form_submit_button(
            'Choose the defined concept and space to disentangle'
        )

    # Persist the selection in session state only once the form is submitted.
    if choose_text_button:
        concept_id, space_id = str(concept_id), str(space_id)
        st.session_state.concept_id = concept_id
        st.session_state.space_id = space_id
# ---------------------------- SET UP OUTPUT ------------------------------
# Placeholders and column grids are created up front and filled in by later
# sections; the order of creation fixes the on-page layout.
epsilon_container = st.empty()

st.subheader('Concept vector')
# Wide/narrow column pairs: a header row above the concept-vector output row.
header_col_1, header_col_2 = st.columns([5, 1])
output_col_1, output_col_2 = st.columns([5, 1])

st.subheader('Derivations along the concept vector')
error_container = st.empty()
smoothgrad_header_container = st.empty()

# Five equal-width columns: a caption row (smooth_head_*) above a content
# row (smoothgrad_col_*).
_five_equal = [1] * 5
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns(_five_equal)
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns(_five_equal)
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
# Obtain the concept (separation) vector from the backend for the selected
# concept and latent space, show it, and put a summary in the header column.
with output_col_1:
    (separation_vector,
     number_important_features,
     imp_nodes,
     performance) = get_separation_space(
        concept_id, annotations, ann_df, latent_space=st.session_state.space_id
    )
    st.write('Concept vector', separation_vector)

    _summary = (
        f'Concept {concept_id} - Space {st.session_state.space_id} - '
        f'Number of relevant nodes: {number_important_features} - '
        f'Val classification performance: {performance}'
    )
    header_col_1.write(_summary)
# ----------------------------- INPUT column 2 & 3 ----------------------------
with input_col_2:
    with st.form('image_form'):
        st.write('**Choose or generate a random image to test the disentanglement**')
        # Placeholder, so the ID field can be redrawn with a randomly drawn ID below.
        chosen_image_id_input = st.empty()
        # min_value=0: the ID indexes the annotation vectors directly, and a
        # negative index would silently wrap around instead of failing.
        image_id = chosen_image_id_input.number_input(
            'Image ID:', format='%d', step=1, min_value=0, value=st.session_state.image_id
        )
        choose_image_button = st.form_submit_button('Choose the defined image')
        random_id = st.form_submit_button('Generate a random image')
        projection_id = st.form_submit_button('Generate an image on the boundary')

    if random_id or projection_id:
        previous_image_id = st.session_state.image_id
        # NOTE(review): assumes the annotations cover seeds 0-50000 inclusive — confirm.
        image_id = random.randint(0, 50000)
        st.session_state.image_id = image_id
        # Redraw the ID field with the new value. When the draw repeats the ID
        # already shown, skip it: a second widget with identical parameters
        # would raise Streamlit's duplicate-widget error.
        if image_id != previous_image_id:
            chosen_image_id_input.number_input(
                'Image ID:', format='%d', step=1, min_value=0, value=image_id
            )
        # Only the boundary button requests projection; a plain random draw clears it.
        st.session_state.projection = bool(projection_id)

    if choose_image_button:
        image_id = int(image_id)
        st.session_state.image_id = image_id
# Lambda range and the generator layers to manipulate.
with input_col_3, st.form('Variate along the disentangled concept'):
    st.write('**Set range of change**')
    chosen_epsilon_input = st.empty()
    epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
    epsilon_button = st.form_submit_button('Choose the defined lambda')

    st.write('**Select hierarchical levels to manipulate**')
    # An empty selection is passed on as None rather than an empty list.
    layers = st.multiselect('Layers:', tuple(range(14))) or None
    layers_button = st.form_submit_button('Choose the defined layers')
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
# NOTE(review): the generator is reloaded from disk on every script rerun;
# consider caching it if load time becomes noticeable.
with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore

# Starting latent of the selected image, taken from the selected space.
if st.session_state.space_id == 'Z':
    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
else:
    original_image_vec = annotations['w_vectors'][st.session_state.image_id]

if st.session_state.projection:
    # Remove the component along the concept direction, placing the latent on
    # the hyperplane through the origin orthogonal to separation_vector.
    # Dividing by ||separation_vector||^2 keeps the projection exact even if
    # the vector is not unit-norm (a no-op when it is).
    original_image_vec = original_image_vec - (
        np.dot(original_image_vec, separation_vector)
        / np.dot(separation_vector, separation_vector)
    ) * separation_vector

img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
with smoothgrad_col_3:
    st.image(img)
    smooth_head_3.write('Base image')

images, lambdas = regenerate_images(
    model, original_image_vec, separation_vector,
    min_epsilon=-int(epsilon), max_epsilon=int(epsilon),
    latent_space=st.session_state.space_id, layers=layers,
)

# Outer columns show the variations for lambdas[0], [1], [3] and [4].
# NOTE(review): assumes images[i] was generated for lambdas[i] and that index 2
# (not shown) is the unshifted step next to the base image — confirm against
# regenerate_images.
_variation_slots = (
    (smoothgrad_col_1, smooth_head_1, 0),
    (smoothgrad_col_2, smooth_head_2, 1),
    (smoothgrad_col_4, smooth_head_4, 3),
    (smoothgrad_col_5, smooth_head_5, 4),
)
for _col, _head, _idx in _variation_slots:
    with _col:
        st.image(images[_idx])
        _head.write(f'Change of {np.round(lambdas[_idx], 2)}')