# latent-space-theories / pages / 3_Vectors_algebra.py
# Commit 1cb3b5a by ludusc: "clean up"
# (file-view metadata from the hosting page: raw · history · blame · no virus · 6.76 kB)
import streamlit as st
import pickle
import pandas as pd
import numpy as np
import random
import torch
from matplotlib.backends.backend_agg import RendererAgg
from backend.disentangle_concepts import *
import torch_utils
import dnnlib
import legacy
# Matplotlib Agg renderer lock, kept at module level (not referenced elsewhere
# in the code visible in this file).
_lock = RendererAgg.lock

# Wide layout: the page arranges its content in rows of four columns.
st.set_page_config(layout='wide')

# Palette constants (not referenced elsewhere in the visible code).
BACKGROUND_COLOR, SECONDARY_COLOR = '#bcd0e7', '#bce7db'

# Introductory text shown under the page title.
_INTRO_TEXT = """
This page offers the possibility to edit the colors of a given textile image using vector algebra and projections.
It allows to select several colors to move towards and against (selecting a positive or negative lambda).
Furthermore, it offers the possibility of conditional manipulation, by moving in the direction of a color n1 without affecting the color n2.
This is done using a projected direction n1 - (n1.T n2) n2.
"""

st.title('Vector algebra using disentangled vectors')
st.markdown(_INTRO_TEXT, unsafe_allow_html=False)
# Precomputed annotations; the 'w_vectors' entry is indexed by image id further
# down to obtain the latent vector of the selected image.
annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
with open(annotations_file, 'rb') as f:
    # NOTE: pickle can execute arbitrary code on load; only load trusted local files.
    annotations = pickle.load(f)

# Stored concept directions with at least 'color', 'vector' and 'score' columns.
# 'vector' is stored as a ', '-separated string of floats and parsed into arrays.
concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
concept_vectors['vector'] = [
    np.array(list(map(float, parts)))
    for parts in concept_vectors['vector'].str.split(', ')
]
concept_vectors['score'] = concept_vectors['score'].astype(float)
# Highest score first, so the first row found for a color is its best vector.
concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()

# Network used to render images: the 'G_ema' entry of the snapshot, kept on CPU.
with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore
# Directions the user can edit along (hues plus the 'Saturation' and 'Value' directions).
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
# Choices for the "do not affect" selector; 'None' means no constraint.
COLORS_NEGATIVE = COLORS_LIST + ['None']

# Seed the session state only on the first run; later reruns keep whatever
# the page has stored since.
_session_defaults = (
    ('image_id', 52921),
    ('colors', [COLORS_LIST[5], COLORS_LIST[7]]),
    ('non_colors', ['None']),
    ('color_lambda', [5]),
)
for _state_key, _state_default in _session_defaults:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ----------------------------- LAYOUT ----------------------------------
# Empty slot reserved near the top of the page (not filled in the visible code).
epsilon_container = st.empty()
st.header('Image Manipulation with Vector Algebra')

# Two rows of four equal-width columns: section headers on top,
# inputs and images underneath.
header_col_1, header_col_2, header_col_3, header_col_4 = st.columns(4)
input_col_1, output_col_2, output_col_3, input_col_4 = st.columns(4)
# --------------------------- INPUT column 1 ---------------------------
# Image ids are used directly as indices into annotations['w_vectors'], so
# restrict them to the valid range. Without a lower bound, a negative id would
# silently pick a vector from the end of the collection.
# NOTE(review): assumes w_vectors is positionally indexed 0..len-1 — confirm.
_n_images = len(annotations['w_vectors'])

with input_col_1:
    with st.form('image_form'):
        st.write('**Choose or generate a random image**')
        # Placeholder so the id box can be redrawn after a random id is drawn.
        chosen_image_id_input = st.empty()
        image_id = chosen_image_id_input.number_input(
            'Image ID:', format='%d', step=1,
            min_value=0, max_value=_n_images - 1,
            value=st.session_state.image_id,
        )
        choose_image_button = st.form_submit_button('Choose the defined image')
        random_id = st.form_submit_button('Generate a random image')

        if random_id:
            # randrange excludes its upper bound, so the id is always a valid
            # index (a hard-coded randint(0, 100000) can overshoot the last one).
            image_id = random.randrange(_n_images)
            st.session_state.image_id = image_id
            # Redraw the box so it shows the newly drawn id.
            chosen_image_id_input.number_input(
                'Image ID:', format='%d', step=1,
                min_value=0, max_value=_n_images - 1,
                value=st.session_state.image_id,
            )

        if choose_image_button:
            image_id = int(image_id)
            st.session_state.image_id = image_id
# Header for the input column.
header_col_1.write('### Input image selection')

# Latent vector of the selected image; it is also the starting point for the
# composite edit rendered further down.
original_image_vec = annotations['w_vectors'][st.session_state.image_id]
img = generate_original_image(original_image_vec, model)
output_col_2.image(img)

header_col_2.write('### Original image')
with input_col_4:
    with st.form('text_form_1'):
        st.write('**Colors to vary (including Saturation and Value)**')
        colors = st.multiselect('Color:', tuple(COLORS_LIST), default=[COLORS_LIST[5], COLORS_LIST[7]])
        # Either submit button sends every widget value in this form.
        colors_button = st.form_submit_button('Choose the defined colors')
        st.session_state.image_id = image_id
        st.session_state.colors = colors

        # One lambda box and one "do not affect" selector per chosen color.
        # Keys derive from the color name, so each color keeps its own widget
        # state while it stays selected.
        lambdas = []
        negative_cols = []
        for color in colors:
            st.write('### ' + color)
            st.write('**Set range of change (can be negative)**')
            lambdas.append(
                st.number_input('Lambda:', min_value=-100, step=1, value=5, key=color + '_number')
            )
            st.write('**Set dimensions of change to not consider**')
            negative_cols.append(
                st.selectbox('Color:', tuple(COLORS_NEGATIVE), index=len(COLORS_NEGATIVE) - 1,
                             key=color + '_noncolor')
            )
        lambdas_button = st.form_submit_button('Submit options')

# Form widgets return their last submitted values on every rerun. Storing them
# unconditionally keeps the rendered edit in sync with what the form displays.
# Previously these were reset to [5, ...] / ['None', ...] on every rerun that
# was not a "Submit options" click, e.g. after picking a new image.
st.session_state.color_lambda = lambdas
st.session_state.non_colors = negative_cols
# Header for the color-settings column.
header_col_4.write('### Color settings')
# ---------------------------- OUTPUT HEADER (column 3) ------------------------------
def _best_concept_vector(color_name):
    """Return the highest-scoring stored vector for ``color_name``.

    ``concept_vectors`` is sorted by descending score when it is loaded, so the
    first row matching ``color_name`` holds its best vector.

    Raises:
        KeyError: if no stored row has ``color_name`` as its color.
    """
    matches = concept_vectors[concept_vectors['color'] == color_name].reset_index(drop=True)
    return matches.loc[0, 'vector']


with header_col_3:
    # Directions to move along, one per selected color.
    separation_vectors = [_best_concept_vector(col) for col in st.session_state.colors]

    # Directions to leave unaffected. The string 'None' marks "no constraint"
    # and is passed through as-is to the composite generator.
    negative_separation_vectors = [
        'None' if non_col == 'None' else _best_concept_vector(non_col)
        for non_col in st.session_state.non_colors
    ]

    # The conditional direction n1 - (n1^T n2) n2 is presumably built inside
    # generate_composite_images from the two lists above — TODO confirm.
    st.write('### Output Image')
    # Explicit concatenation: the previous backslash-continued triple-quoted
    # string depended on source indentation and could drop the space after ','.
    st.write(
        f'Change in colors: {st.session_state.colors}, '
        f'without affecting colors {st.session_state.non_colors}'
    )
# ---------------------------- EDITED IMAGE (column 3) ------------------------------
with output_col_3:
    # Render the edited image from the original latent, the selected color
    # directions, their lambdas and the per-color "do not affect" entries.
    composite_kwargs = dict(
        lambdas=st.session_state.color_lambda,
        negative_colors=negative_separation_vectors,
    )
    image_updated = generate_composite_images(
        model, original_image_vec, separation_vectors, **composite_kwargs
    )
    st.image(image_updated)