# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# An interactive glumpy (OpenGL) + tkinter viewer for interacting with principal components.
# Requires OpenGL and CUDA support for rendering.

import torch
import numpy as np
import tkinter as tk
from tkinter import ttk
from types import SimpleNamespace
import matplotlib.pyplot as plt
from pathlib import Path
from os import makedirs
from models import get_instrumented_model
from config import Config
from decomposition import get_or_compute
from torch.nn.functional import interpolate
from TkTorchWindow import TorchImageView
from functools import partial
from platform import system
from PIL import Image
from utils import pad_frames, prettify_name
import pickle

# For platform specific UI tweaks
is_windows = 'Windows' in system()
is_linux = 'Linux' in system()
is_mac = 'Darwin' in system()

# Read input parameters
args = Config().from_args()

# Don't bother without GPU
assert torch.cuda.is_available(), 'Interactive mode requires CUDA'


# Use syntax from paper
def get_edit_name(idx, s, e, name=None):
    """Build the display label 'E(idx, s-e): name' for an edit.

    The range is shown as 's-e' only when e > s, otherwise just 's'.
    The ': name' suffix is omitted when `name` is empty or None.
    """
    return 'E({comp}, {edit_range}){edit_name}'.format(
        comp=idx,
        edit_range=f'{s}-{e}' if e > s else s,
        edit_name=f': {name}' if name else '',
    )


# Load or compute PCA basis vectors
def load_components(class_name, inst):
    """Load the decomposition dump for `class_name` and publish it as `components`.

    Sets the module globals `components` and `use_named_latents` (to False) and
    records `class_name` in `state.component_class`.
    """
    global components, state, use_named_latents

    config = args.from_dict({'output_class': class_name})
    dump_name = get_or_compute(config, inst)

    # Context manager guarantees the archive is closed even if a key is missing.
    with np.load(dump_name, allow_pickle=False) as data:
        X_comp = data['act_comp']
        X_mean = data['act_mean']
        X_stdev = data['act_stdev']
        Z_comp = data['lat_comp']
        Z_mean = data['lat_mean']
        Z_stdev = data['lat_stdev']
        random_stdev_act = np.mean(data['random_stdevs'])  # currently unused
        n_comp = X_comp.shape[0]

    # Transfer to GPU
    components = SimpleNamespace(
        X_comp=torch.from_numpy(X_comp).cuda().float(),
        X_mean=torch.from_numpy(X_mean).cuda().float(),
        X_stdev=torch.from_numpy(X_stdev).cuda().float(),
        Z_comp=torch.from_numpy(Z_comp).cuda().float(),
        Z_stdev=torch.from_numpy(Z_stdev).cuda().float(),
        Z_mean=torch.from_numpy(Z_mean).cuda().float(),
        names=[f'Component {i}' for i in range(n_comp)],
        latent_types=[model.latent_space_name()] * n_comp,
        ranges=[(0, model.get_max_latents())] * n_comp,
    )

    state.component_class = class_name  # invalidates cache
    use_named_latents = False
    print('Loaded components for', class_name, 'from', dump_name)


# Load previously exported named components from
# directory specified with '--inputs=path/to/comp'
def load_named_components(path, class_name):
    """Load every matching '*.pkl' direction in `path` into `components`.

    A dump is used only if its model name, output class and latent space match
    the currently loaded model. Sets the module globals `components` and
    `use_named_latents` (to True).

    Raises:
        RuntimeError: if no dump in `path` matches.
    """
    global components, state, use_named_latents
    import glob

    matches = glob.glob(f'{path}/*.pkl')

    selected = []
    for dump_path in matches:
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point '--inputs' at directories whose contents you trust.
        with open(dump_path, 'rb') as f:
            data = pickle.load(f)

        if data['model_name'] != model_name or data['output_class'] != class_name:
            continue

        if data['latent_space'] != model.latent_space_name():
            print('Skipping', dump_path, '(wrong latent space)')
            continue

        selected.append(data)
        print('Using', dump_path)

    if not selected:
        raise RuntimeError('No valid components in given path.')

    comp_dict = {
        k: [] for k in [
            'X_comp', 'Z_comp', 'X_stdev', 'Z_stdev', 'names',
            'types', 'layer_names', 'ranges', 'latent_types',
        ]
    }
    components = SimpleNamespace(**comp_dict)

    for d in selected:
        s = d['edit_start']
        e = d['edit_end']
        title = get_edit_name(d['component_index'], s, e - 1, d['name'])  # show inclusive
        components.X_comp.append(torch.from_numpy(d['act_comp']).cuda())
        components.Z_comp.append(torch.from_numpy(d['lat_comp']).cuda())
        components.X_stdev.append(d['act_stdev'])
        components.Z_stdev.append(d['lat_stdev'])
        components.names.append(title)
        components.types.append(d['edit_type'])
        components.layer_names.append(d['decomposition']['layer'])  # only for act
        components.ranges.append((s, e))
        components.latent_types.append(d['latent_space'])  # W or Z

    use_named_latents = True
    print('Loaded named components')


def setup_model():
    """Instantiate the instrumented model from `args` and load its components.

    Populates the module globals `model`, `inst`, `layer_name`, `model_name`,
    `feat_shape` and `class_name`, disables autograd globally, and loads either
    named components (when `args.inputs` is set) or the computed basis.
    """
    global model, inst, layer_name, model_name, feat_shape, args, class_name

    model_name = args.model
    layer_name = args.layer
    class_name = args.output_class

    # Speed up pytorch
    torch.autograd.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    # Load model
    inst = get_instrumented_model(
        model_name, class_name, layer_name, torch.device('cuda'), use_w=args.use_w)
    model = inst.model

    feat_shape = inst.feature_shape[layer_name]

    # Initialize
    if args.inputs:
        load_named_components(args.inputs, class_name)
    else:
        load_components(class_name, inst)


# Project tensor 'X' onto orthonormal basis 'comp', return coordinates
def project_ortho(X, comp):
    """Return the coordinates of `X` in basis `comp`.

    Each of the N = comp.shape[0] basis vectors is flattened and dotted with
    the flattened `X`. The result is reshaped to [N, 1, ..., 1] with X.ndim
    trailing singleton dimensions.
    """
    N = comp.shape[0]
    coords = (comp.reshape(N, -1) * X.reshape(-1)).sum(dim=1)
    return coords.reshape([N] + [1] * X.ndim)


def zero_sliders():
    """Set every component slider to 0."""
    for v in ui_state.sliders:
        v.set(0.0)


def reset_sliders(zero_on_failure=True):
    """Reset the slider offsets and, where supported, re-project onto the sliders.

    The projection path is currently switched off via the local `enabled` flag,
    so this always clears both slider offsets and, if `zero_on_failure` is
    true, zeroes the sliders.
    """
    global ui_state

    mode = ui_state.mode.get()

    # Not orthogonal: need to solve least-norm problem
    # Not batch size 1: one set of sliders not enough
    # Not principal components: unsupported format
    is_ortho = not (mode == 'latent' and model.latent_space_name() == 'Z')
    is_single = state.z.shape[0] == 1
    is_pcs = not use_named_latents

    state.lat_slider_offset = 0
    state.act_slider_offset = 0

    # Hard-coded off: the code below the early return is intentionally dormant.
    enabled = False
    if not (enabled and is_ortho and is_single and is_pcs):
        if zero_on_failure:
            zero_sliders()
        return

    if mode == 'activation':
        val = state.base_act
        mean = components.X_mean
        comp = components.X_comp
        stdev = components.X_stdev
    else:
        val = state.z
        mean = components.Z_mean
        comp = components.Z_comp
        stdev = components.Z_stdev

    n_sliders = len(ui_state.sliders)
    coords = project_ortho(val - mean, comp)
    offset = torch.sum(coords[:n_sliders] * comp[:n_sliders], dim=0)
    scaled_coords = (coords.view(-1) / stdev).detach().cpu().numpy()

    # Part representable by sliders
    if mode == 'activation':
        state.act_slider_offset = offset
    else:
        state.lat_slider_offset = offset

    for i in range(n_sliders):
        ui_state.sliders[i].set(round(scaled_coords[i], ndigits=1))


def setup_ui():
    """Create the image window and the toolbar window with all controls.

    Populates the module globals `root`, `toolbar`, `ui_state`, `app` and
    `canvas`.
    """
    global root, toolbar, ui_state, app, canvas

    root = tk.Tk()
    scale = 1.0
    app = TorchImageView(root, width=int(scale * 1024), height=int(scale * 1024), show_fps=False)
    app.pack(fill=tk.BOTH, expand=tk.YES)
    root.protocol("WM_DELETE_WINDOW", shutdown)
    root.title('GANspace')

    toolbar = tk.Toplevel(root)
    toolbar.protocol("WM_DELETE_WINDOW", shutdown)
    toolbar.geometry("215x800+0+0")
    toolbar.title('')

    N_COMPONENTS = min(70, len(components.names))
    ui_state = SimpleNamespace(
        sliders=[tk.DoubleVar(value=0.0) for _ in range(N_COMPONENTS)],
        scales=[],
        truncation=tk.DoubleVar(value=0.9),
        outclass=tk.StringVar(value=class_name),
        random_seed=tk.StringVar(value='0'),
        mode=tk.StringVar(value='latent'),
        batch_size=tk.IntVar(value=1),  # how many images to show in window
        edit_layer_start=tk.IntVar(value=0),
        edit_layer_end=tk.IntVar(value=model.get_max_latents() - 1),
        slider_max_val=10.0,
    )

    # Z vs activation mode button
    #tk.Radiobutton(toolbar, text=f"Latent ({model.latent_space_name()})", variable=ui_state.mode, command=reset_sliders, value='latent').pack(fill="x")
    #tk.Radiobutton(toolbar, text="Activation", variable=ui_state.mode, command=reset_sliders, value='activation').pack(fill="x")

    # Choose range where latents are modified
    def set_min(val):
        # Clamp so that start never exceeds end
        ui_state.edit_layer_start.set(min(int(val), ui_state.edit_layer_end.get()))

    def set_max(val):
        # Clamp so that end never drops below start
        ui_state.edit_layer_end.set(max(int(val), ui_state.edit_layer_start.get()))

    max_latent_idx = model.get_max_latents() - 1

    if not use_named_latents:
        tk.Scale(toolbar, command=set_min, variable=ui_state.edit_layer_start,
                 label='Layer start', from_=0, to=max_latent_idx,
                 orient=tk.HORIZONTAL).pack(fill="x")
        tk.Scale(toolbar, command=set_max, variable=ui_state.edit_layer_end,
                 label='Layer end', from_=0, to=max_latent_idx,
                 orient=tk.HORIZONTAL).pack(fill="x")

    # Scrollable list of components
    outer_frame = tk.Frame(toolbar, borderwidth=2, relief=tk.SUNKEN)
    canvas = tk.Canvas(outer_frame, highlightthickness=0, borderwidth=0)
    frame = tk.Frame(canvas)
    vsb = tk.Scrollbar(outer_frame, orient="vertical", command=canvas.yview)
    canvas.configure(yscrollcommand=vsb.set)

    vsb.pack(side="right", fill="y")
    canvas.pack(side="left", fill="both", expand=True)
    canvas.create_window((4, 4), window=frame, anchor="nw")

    def onCanvasConfigure(event):
        # Keep the embedded frame as wide as the canvas and refresh the scroll area
        canvas.itemconfigure("all", width=event.width)
        canvas.configure(scrollregion=canvas.bbox("all"))

    # The event names below were empty strings, which Tk rejects; each one is
    # restored from what its handler reads (width / num / delta / keysym_num).
    canvas.bind("<Configure>", onCanvasConfigure)

    def on_scroll(event):
        delta = 1 if (event.num == 5 or event.delta < 0) else -1
        canvas.yview_scroll(delta, "units")

    canvas.bind_all("<Button-4>", on_scroll)    # X11 wheel up
    canvas.bind_all("<Button-5>", on_scroll)    # X11 wheel down
    canvas.bind_all("<MouseWheel>", on_scroll)  # wheel events carrying event.delta
    canvas.bind_all("<Key>", lambda event: handle_keypress(event.keysym_num))

    # Sliders and buttons
    for i in range(N_COMPONENTS):
        inner = tk.Frame(frame, borderwidth=1, background="#aaaaaa")
        scale = tk.Scale(inner, variable=ui_state.sliders[i],
                         from_=-ui_state.slider_max_val, to=ui_state.slider_max_val,
                         resolution=0.1, orient=tk.HORIZONTAL, label=components.names[i])
        scale.pack(fill=tk.X, side=tk.LEFT, expand=True)
        ui_state.scales.append(scale)  # for changing label later
        if not use_named_latents:
            tk.Button(inner, text="Save",
                      command=partial(export_direction, i, inner)).pack(fill=tk.Y, side=tk.RIGHT)
        inner.pack(fill=tk.X)

    outer_frame.pack(fill="both", expand=True, pady=0)

    tk.Button(toolbar, text="Reset", command=reset_sliders).pack(
        anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)

    tk.Scale(toolbar, variable=ui_state.truncation, from_=0.01, to=1.0,
             resolution=0.01, orient=tk.HORIZONTAL, label='Truncation').pack(fill="x")
    tk.Scale(toolbar, variable=ui_state.batch_size, from_=1, to=9,
             resolution=1, orient=tk.HORIZONTAL, label='Batch size').pack(fill="x")

    # Output class
    frame = tk.Frame(toolbar)
    tk.Label(frame, text="Class name").pack(fill="x", side="left")
    tk.Entry(frame, textvariable=ui_state.outclass).pack(
        fill="x", side="right", expand=True, padx=5)
    frame.pack(fill=tk.X, pady=3)

    # Random seed
    def update_seed():
        seed_str = ui_state.random_seed.get()
        if seed_str.isdigit():
            resample_latent(int(seed_str))

    frame = tk.Frame(toolbar)
    tk.Label(frame, text="Seed").pack(fill="x", side="left")
    tk.Entry(frame, textvariable=ui_state.random_seed, width=12).pack(
        fill="x", side="left", expand=True, padx=2)
    tk.Button(frame, text="Update", command=update_seed).pack(fill="y", side="right", padx=3)
    frame.pack(fill=tk.X, pady=3)

    # Get new latent or new components
    tk.Button(toolbar, text="Resample latent",
              command=partial(resample_latent, None, False)).pack(
        anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)
    #tk.Button(toolbar, text="Recompute", command=recompute_components).pack(anchor=tk.CENTER, fill=tk.X)


# App state
state = SimpleNamespace(
    z=None,                # current latent(s)
    lat_slider_offset=0,   # part of lat that is explained by sliders
    act_slider_offset=0,   # part of act that is explained by sliders
    component_class=None,  # name of current PCs' image class
    seed=0,                # Latent z_i generated by seed+i
    base_act=None,         # activation of considered layer given z
)


def resample_latent(seed=None, only_style=False):
    """Draw a new batch of latents and refresh the cached base activation.

    Args:
        seed: seed of the first latent in the batch; a random one is drawn
            when None. Latent i of the batch uses `seed + i`.
        only_style: accepted but not used by this function.
    """
    class_name = ui_state.outclass.get()
    if class_name.isnumeric():
        class_name = int(class_name)

    # Silently ignore the request if the model rejects the class
    if hasattr(model, 'is_valid_class'):
        if not model.is_valid_class(class_name):
            return

    model.set_output_class(class_name)

    B = ui_state.batch_size.get()
    state.seed = np.random.randint(np.iinfo(np.int32).max - B) if seed is None else seed
    ui_state.random_seed.set(str(state.seed))

    # Use consecutive seeds along batch dimension (for easier reproducibility)
    trunc = ui_state.truncation.get()
    latents = [model.sample_latent(1, seed=state.seed + i, truncation=trunc) for i in range(B)]

    state.z = torch.cat(latents).clone().detach()  # make leaf node
    assert state.z.is_leaf, 'Latent is not leaf node!'

    if hasattr(model, 'truncation'):
        model.truncation = ui_state.truncation.get()
    print(f'Seeds: {state.seed} -> {state.seed + B - 1}' if B > 1 else f'Seed: {state.seed}')

    torch.manual_seed(state.seed)
    model.partial_forward(state.z, layer_name)
    state.base_act = inst.retained_features()[layer_name]

    reset_sliders(zero_on_failure=False)

    # Remove focus from text entry
    canvas.focus_set()


# Used to recompute after changing class of conditional model
def recompute_components():
    """Reload the components for the class currently entered in the UI."""
    class_name = ui_state.outclass.get()
    if class_name.isnumeric():
        class_name = int(class_name)

    if hasattr(model, 'is_valid_class'):
        if not model.is_valid_class(class_name):
            return

    if hasattr(model, 'set_output_class'):
        model.set_output_class(class_name)

    load_components(class_name, inst)


# Used to detect parameter changes for lazy recomputation
class ParamCache():
    """Remembers the last value seen for each named parameter."""

    def update(self, **kwargs):
        """Store the given values; return True if any differs from the stored one.

        A value counts as unchanged if it is the same object as the stored one
        or if both pickle to identical bytes. Parameters never seen before are
        compared against 0.
        """
        dirty = False
        for argname, val in kwargs.items():
            # Check pointer, then value
            current = getattr(self, argname, 0)
            if current is not val and pickle.dumps(current) != pickle.dumps(val):
                setattr(self, argname, val)
                dirty = True
        return dirty


cache = ParamCache()


def l2norm(t):
    """Return the per-sample L2 norm of `t`, flattened past dim 0, as a [B, 1] tensor."""
    return torch.norm(t.view(t.shape[0], -1), p=2, dim=1, keepdim=True)


def apply_edit(z0, delta):
    """Apply an additive edit `delta` to latent `z0`."""
    return z0 + delta


def reposition_toolbar():
    """Dock the toolbar to the left edge of the main window, matching its height."""
    size, X, Y = root.winfo_geometry().split('+')
    _, H = size.split('x')
    toolbar_W = toolbar.winfo_geometry().split('x')[0]
    offset_y = -30 if is_linux else 0  # window title bar
    toolbar.geometry(f'{toolbar_W}x{H}+{int(X)-int(toolbar_W)}+{int(Y)+offset_y}')
    toolbar.update()


def on_draw():
    """Re-render the image if any input changed, then draw it.

    The model is only re-evaluated when `cache.update` reports a change in the
    slider values, component class, mode, latent or edit range. The result is
    kept in the module global `img`.
    """
    global img
    slider_vals = np.array([s.get() for s in ui_state.sliders], dtype=np.float32)

    # Run model sparingly
    mode = ui_state.mode.get()
    latent_start = ui_state.edit_layer_start.get()
    latent_end = ui_state.edit_layer_end.get() + 1  # save as exclusive, show as inclusive

    if cache.update(coords=slider_vals, comp=state.component_class, mode=mode,
                    z=state.z, s=latent_start, e=latent_end):
        with torch.no_grad():
            z_base = state.z - state.lat_slider_offset
            z_deltas = [0.0] * model.get_max_latents()  # one offset per latent index
            z_delta_global = 0.0                        # offset shared by all indices

            n_comp = slider_vals.size
            act_deltas = {}

            if torch.is_tensor(state.act_slider_offset):
                act_deltas[layer_name] = -state.act_slider_offset

            for space in components.latent_types:
                assert space == model.latent_space_name(), 'Cannot mix latent spaces (for now)'

            for c in range(n_comp):
                coord = slider_vals[c]
                if coord == 0:
                    continue

                edit_mode = components.types[c] if use_named_latents else mode

                # Activation offset
                if edit_mode in ['activation', 'both']:
                    delta = components.X_comp[c] * components.X_stdev[c] * coord
                    name = components.layer_names[c] if use_named_latents else layer_name
                    act_deltas[name] = act_deltas.get(name, 0.0) + delta

                # Latent offset
                if edit_mode in ['latent', 'both']:
                    delta = components.Z_comp[c] * components.Z_stdev[c] * coord
                    edit_range = (components.ranges[c] if use_named_latents
                                  else (latent_start, latent_end))
                    full_range = (edit_range == (0, model.get_max_latents()))

                    # Single or multiple offsets?
                    if full_range:
                        z_delta_global = z_delta_global + delta
                    else:
                        for layer_idx in range(*edit_range):
                            z_deltas[layer_idx] = z_deltas[layer_idx] + delta

            # Apply activation deltas
            inst.remove_edits()
            for layer, delta in act_deltas.items():
                inst.edit_layer(layer, offset=delta)

            # Evaluate
            has_offsets = any(torch.is_tensor(t) for t in z_deltas)
            z_final = apply_edit(z_base, z_delta_global)
            if has_offsets:
                # Per-index edits: pass the model a list with one latent per index
                z_final = [apply_edit(z_final, d) for d in z_deltas]
            img = model.forward(z_final).clamp(0.0, 1.0)

    app.draw(img)


# Save necessary data to disk for later loading
def export_direction(idx, button_frame):
    """Ask for a name and export component `idx` as a pickled direction.

    Exactly one slider (the one at `idx`) must be non-zero; its value is stored
    as the usable sigma range. The pickle is written to
    '<this dir>/out/directions/<ident>.pkl'. Example images are rendered too,
    but only when `num_strips` is non-zero (its UI control is disabled, so it
    is 0 by default).

    Args:
        idx: index of the component / slider to export.
        button_frame: widget next to which the name popup is placed.
    """
    name = tk.StringVar(value='')
    num_strips = tk.IntVar(value=0)
    strip_width = tk.IntVar(value=5)

    slider_values = np.array([s.get() for s in ui_state.sliders])
    slider_value = slider_values[idx]
    if (slider_values != 0).sum() > 1:
        print('Please modify only one slider')
        return
    elif slider_value == 0:
        print('Modify selected slider to set usable range (currently 0)')
        return

    popup = tk.Toplevel(root)
    popup.geometry("200x200+0+0")
    tk.Label(popup, text="Edit name").pack()
    tk.Entry(popup, textvariable=name).pack(pady=5)
    # tk.Scale(popup, from_=0, to=30, variable=num_strips,
    #     resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strips to export').pack()
    # tk.Scale(popup, from_=3, to=15, variable=strip_width,
    #     resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strip width').pack()
    tk.Button(popup, text='OK', command=popup.quit).pack()

    canceled = False

    def on_close():
        nonlocal canceled
        canceled = True
        popup.quit()

    popup.protocol("WM_DELETE_WINDOW", on_close)
    x = button_frame.winfo_rootx()
    y = button_frame.winfo_rooty()
    w = int(button_frame.winfo_geometry().split('x')[0])
    popup.geometry('%dx%d+%d+%d' % (180, 90, x + w, y))
    popup.mainloop()  # blocks until OK or window close calls popup.quit
    popup.destroy()

    # Update slider name
    # NOTE(review): the label is changed even when the popup was canceled -
    # confirm whether that is intended.
    label = get_edit_name(idx, ui_state.edit_layer_start.get(),
                          ui_state.edit_layer_end.get(), name.get())
    ui_state.scales[idx].config(label=label)

    if canceled:
        return

    params = {
        'name': name.get(),
        'sigma_range': slider_value,
        'component_index': idx,
        'act_comp': components.X_comp[idx].detach().cpu().numpy(),
        'lat_comp': components.Z_comp[idx].detach().cpu().numpy(),  # either Z or W
        'latent_space': model.latent_space_name(),
        'act_stdev': components.X_stdev[idx].item(),
        'lat_stdev': components.Z_stdev[idx].item(),
        'model_name': model_name,
        'output_class': ui_state.outclass.get(),  # applied onto
        'decomposition': {
            'name': args.estimator,
            'components': args.components,
            'samples': args.n,
            'layer': args.layer,
            'class_name': state.component_class,  # computed from
        },
        'edit_type': ui_state.mode.get(),
        'truncation': ui_state.truncation.get(),
        'edit_start': ui_state.edit_layer_start.get(),
        'edit_end': ui_state.edit_layer_end.get() + 1,  # show as inclusive, save as exclusive
        'example_seed': state.seed,
    }

    edit_mode_str = params['edit_type']
    if edit_mode_str == 'latent':
        edit_mode_str = model.latent_space_name().lower()

    comp_class = state.component_class
    appl_class = params['output_class']
    if comp_class != appl_class:
        comp_class = f'{comp_class}_onto_{appl_class}'

    file_ident = "{model}-{name}-{cls}-{est}-{mode}-{layer}-comp{idx}-range{start}-{end}".format(
        model=model_name,
        name=prettify_name(params['name']),
        cls=comp_class,
        est=args.estimator,
        mode=edit_mode_str,
        layer=args.layer,
        idx=idx,
        start=params['edit_start'],
        end=params['edit_end'],
    )

    out_dir = Path(__file__).parent / 'out' / 'directions'
    makedirs(out_dir / file_ident, exist_ok=True)

    with open(out_dir / f"{file_ident}.pkl", 'wb') as outfile:
        pickle.dump(params, outfile)

    print(f'Direction "{name.get()}" saved as "{file_ident}.pkl"')

    batch_size = ui_state.batch_size.get()
    # Round the strip count up to a whole number of batches
    len_padded = ((num_strips.get() - 1) // batch_size + 1) * batch_size
    orig_seed = state.seed

    reset_sliders()

    # Limit max resolution
    max_H = 512
    ratio = min(1.0, max_H / inst.output_shape[2])

    strips = [[] for _ in range(len_padded)]
    for b in range(0, len_padded, batch_size):
        # Resample
        resample_latent((orig_seed + b) % np.iinfo(np.int32).max)

        sigmas = np.linspace(slider_value, -slider_value, strip_width.get(), dtype=np.float32)
        for sid, sigma in enumerate(sigmas):
            ui_state.sliders[idx].set(sigma)

            # Advance and show results on screen
            on_draw()
            root.update()
            app.update()

            batch_res = (255 * img).byte().permute(0, 2, 3, 1).detach().cpu().numpy()

            for i, data in enumerate(batch_res):
                # Save individual
                name_nodots = file_ident.replace('.', '_')
                outname = out_dir / file_ident / f"{name_nodots}_ex{b+i}_{sid}.png"
                im = Image.fromarray(data)
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                im = im.resize((int(ratio * im.size[0]), int(ratio * im.size[1])), Image.LANCZOS)
                im.save(outname)
                strips[b + i].append(data)

    for i, strip in enumerate(strips[:num_strips.get()]):
        print(f'Saving strip {i + 1}/{num_strips.get()}', end='\r', flush=True)
        data = np.hstack(pad_frames(strip))
        im = Image.fromarray(data)
        im = im.resize((int(ratio * im.size[0]), int(ratio * im.size[1])), Image.LANCZOS)
        im.save(out_dir / file_ident / f"{file_ident}_ex{i}.png")

    # Reset to original state
    resample_latent(orig_seed)
    ui_state.sliders[idx].set(slider_value)


# Shared by glumpy and tkinter
def handle_keypress(code):
    """Dispatch a key symbol code: ESC closes the app, HOME resets the sliders."""
    if code == 65307:  # ESC
        shutdown()
    elif code == 65360:  # HOME
        reset_sliders()
    elif code == 114:  # R
        pass  #reset_sliders()


def shutdown():
    """Request that the main loop exits after the current iteration."""
    global pending_close
    pending_close = True


def on_key_release(symbol, modifiers):
    """Key-release callback; forwards `symbol` to handle_keypress."""
    handle_keypress(symbol)


if __name__ == '__main__':
    setup_model()
    setup_ui()
    resample_latent()

    # Manual event loop so rendering can be interleaved with Tk event handling
    pending_close = False
    while not pending_close:
        root.update()
        app.update()
        on_draw()
        reposition_toolbar()

    root.destroy()