File size: 6,762 Bytes
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb3b5a
e79558d
 
1cb3b5a
 
 
 
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb3b5a
e79558d
 
 
1cb3b5a
e79558d
 
 
 
 
1cb3b5a
 
e79558d
 
 
 
 
 
1cb3b5a
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb3b5a
e79558d
1cb3b5a
 
e79558d
 
 
 
 
1cb3b5a
e79558d
 
 
 
1cb3b5a
 
e79558d
 
 
 
 
 
 
 
 
 
1cb3b5a
 
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb3b5a
 
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb3b5a
 
 
e79558d
 
 
 
1cb3b5a
 
 
e79558d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import streamlit as st
import pickle
import pandas as pd
import numpy as np
import random
import torch

from matplotlib.backends.backend_agg import RendererAgg

from backend.disentangle_concepts import *
import torch_utils
import dnnlib
import legacy

# Module-level handle to the Agg renderer lock (not used in the visible part of the script).
_lock = RendererAgg.lock

# Palette constants of the page.
BACKGROUND_COLOR = '#bcd0e7'
SECONDARY_COLOR = '#bce7db'

# Use the full browser width for the four-column layout built further down.
st.set_page_config(layout='wide')

# Introductory text shown under the title. It is assembled from explicit pieces
# so that the leading indentation and trailing spaces of every line are visible.
_PAGE_INTRO = (
    "\n"
    "    This page offers the possibility to edit the colors of a given textile image using vector algebra and projections.\n"
    "    It allows to select several colors to move towards and against (selecting a positive or negative lambda). \n"
    "    Furthermore, it offers the possibility of conditional manipulation, by moving in the direction of a color n1 without affecting the color n2. \n"
    "    This is done using a projected direction n1 - (n1.T n2) n2. \n"
    "    "
)

st.title('Vector algebra using disentangled vectors')
st.markdown(_PAGE_INTRO, unsafe_allow_html=False)
    
# Annotations of the pre-generated images, read from a local pickle file.
# NOTE(review): pickle.load executes arbitrary code from the file; only safe
# while this path points at a trusted, project-owned artifact.
annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
with open(annotations_file, 'rb') as annotations_handle:
    annotations = pickle.load(annotations_handle)

# Concept table: one row per color concept. The 'vector' column is stored in the
# CSV as a ', '-separated string of numbers and is parsed into numpy arrays here.
concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
concept_vectors['vector'] = [
    np.array(list(map(float, components)))
    for components in concept_vectors['vector'].str.split(', ')
]
concept_vectors['score'] = concept_vectors['score'].astype(float)

# Highest score first; reset_index() keeps the previous index as an 'index' column.
ranked_concepts = concept_vectors.sort_values('score', ascending=False)
concept_vectors = ranked_concepts.reset_index()

# Generator network ('G_ema' entry of the snapshot), placed on the CPU.
with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as snapshot_handle:
    network_snapshot = legacy.load_network_pkl(snapshot_handle)
    model = network_snapshot['G_ema'].to('cpu') # type: ignore

# Color concepts the user can move towards/against, plus the two HSV channels.
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
# Choices for "do not affect this color"; 'None' means no conditioning.
COLORS_NEGATIVE = COLORS_LIST + ['None']

# Initial selection: 'Blue' and 'Pink'.
_DEFAULT_COLORS = [COLORS_LIST[5], COLORS_LIST[7]]

# First-run defaults for the per-session state. The lambda and negative-color
# lists hold one entry per selected color, so they are derived from the default
# selection instead of being hard-coded to a single element.
_SESSION_DEFAULTS = {
    'image_id': 52921,
    'colors': list(_DEFAULT_COLORS),
    'non_colors': ['None'] * len(_DEFAULT_COLORS),
    'color_lambda': [5] * len(_DEFAULT_COLORS),
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    # Only fill in keys that are missing so values from earlier reruns survive.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# ----------------------------- INPUT ----------------------------------
# Empty placeholder element created above the section header.
epsilon_container = st.empty()
st.header('Image Manipulation with Vector Algebra')

# Two rows of four equal-width columns: the first row carries the column
# titles, the second one the forms and images.
header_col_1, header_col_2, header_col_3, header_col_4 = st.columns([1, 1, 1, 1])
input_col_1, output_col_2, output_col_3, input_col_4 = st.columns([1, 1, 1, 1])

# --------------------------- INPUT column 1 ---------------------------
header_col_1.write('### Input image selection')

with input_col_1, st.form('image_form'):
    st.write('**Choose or generate a random image**')

    # The number input lives in a placeholder so it can be re-rendered below
    # with the freshly drawn id when a random image is requested.
    chosen_image_id_input = st.empty()
    image_id = chosen_image_id_input.number_input(
        'Image ID:',
        format='%d',
        step=1,
        value=st.session_state['image_id'],
    )

    # Two submit buttons for the same form: keep the typed id, or draw one.
    choose_image_button = st.form_submit_button('Choose the defined image')
    random_id = st.form_submit_button('Generate a random image')

    if random_id:
        # random.randint includes both bounds.
        image_id = random.randint(0, 100000)
        st.session_state['image_id'] = image_id
        chosen_image_id_input.number_input(
            'Image ID:',
            format='%d',
            step=1,
            value=st.session_state['image_id'],
        )

    if choose_image_button:
        image_id = int(image_id)
        st.session_state['image_id'] = image_id

# Look up the stored 'w_vectors' entry of the selected image id and render the
# unmodified image from it with the loaded generator.
selected_image_id = st.session_state['image_id']
original_image_vec = annotations['w_vectors'][selected_image_id]
img = generate_original_image(original_image_vec, model)

# Second column: title in the header row, image in the content row.
header_col_2.write('### Original image')
output_col_2.image(img)
    
# Color settings form (fourth column): which concepts to change, by how much,
# and which concept each change must leave untouched.
with input_col_4:
    with st.form('text_form_1'):

        st.write('**Colors to vary (including Saturation and Value)**')
        colors = st.multiselect('Color:', tuple(COLORS_LIST), default=[COLORS_LIST[5], COLORS_LIST[7]])
        st.form_submit_button('Choose the defined colors')

        # image_id is the value produced by the image-selection form earlier
        # in the script.
        st.session_state.image_id = image_id
        st.session_state.colors = colors

        # One lambda and one "do not affect" choice per selected color, in the
        # same order as `colors`.
        lambdas = []
        negative_cols = []
        for color in colors:
            st.write('### '+color )
            st.write('**Set range of change (can be negative)**')
            chosen_color_lambda_input = st.empty()
            color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=5, key=color+'_number')
            lambdas.append(color_lambda)

            st.write('**Set dimensions of change to not consider**')
            chosen_color_negative_input = st.empty()
            color_negative = chosen_color_negative_input.selectbox('Color:', tuple(COLORS_NEGATIVE), index=len(COLORS_NEGATIVE)-1, key=color+'_noncolor')
            negative_cols.append(color_negative)

        st.form_submit_button('Submit options')

        # Always mirror the widget values into the session state. Form widgets
        # return their last submitted value on every rerun, so the applied
        # settings stay identical to what the form displays. Previously the
        # lists were reset to 5 / 'None' on each rerun and only applied on the
        # single run in which 'Submit options' was pressed, so any later rerun
        # (e.g. picking another image) silently dropped the submitted values.
        st.session_state.color_lambda = lambdas
        st.session_state.non_colors = negative_cols


with header_col_4:
    st.write('### Color settings')
# print(st.session_state.colors)
# print(st.session_state.color_lambda)
# print(st.session_state.non_colors)

# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------

def _concept_vector(color_name):
    """Return the 'vector' entry of the first concept row whose color is *color_name*.

    Raises KeyError when no row of ``concept_vectors`` matches the color.
    """
    matching_rows = concept_vectors[concept_vectors['color'] == color_name].reset_index()
    vector, _score = matching_rows.loc[0, ['vector', 'score']]
    return vector


# Direction of change for every selected color.
separation_vectors = [_concept_vector(name) for name in st.session_state.colors]

# Direction each change is conditioned on (n1 - (n1.T n2) n2 in the page text);
# the string 'None' is passed through unchanged when no conditioning was chosen.
negative_separation_vectors = [
    'None' if name == 'None' else _concept_vector(name)
    for name in st.session_state.non_colors
]

with header_col_3:
    st.write('### Output Image')
    st.write(f'''Change in colors: {str(st.session_state.colors)},\
             without affecting colors {str(st.session_state.non_colors)}''')

# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------

# Third column: render the image obtained by applying every selected color
# direction with its lambda and its (optional) conditioning direction.
with output_col_3:
    composite_options = {
        'lambdas': st.session_state.color_lambda,
        'negative_colors': negative_separation_vectors,
    }
    image_updated = generate_composite_images(
        model,
        original_image_vec,
        separation_vectors,
        **composite_options,
    )
    st.image(image_updated)