import torch import os from PIL import Image import numpy as np from ipycanvas import Canvas import cv2 from visualize_attention_src.utils import get_image exp_dir = "saved_attention_map_results" style_name = "line_art" src_name = "cat" tgt_name = "dog" steps = ["20"] seed = "4" saved_dtype = "tensor" attn_map_raws = [] for step in steps: attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}" # new if saved_dtype == 'uint8': attn_map_name = attn_map_name_wo_ext + '_uint8.npy' attn_map_path = os.path.join(exp_dir, attn_map_name) attn_map_raws.append(np.load(attn_map_path, allow_pickle=True)) else: attn_map_name = attn_map_name_wo_ext + '.pt' attn_map_path = os.path.join(exp_dir, attn_map_name) attn_map_raws.append(torch.load(attn_map_path)) print(attn_map_path) attn_map_path = os.path.join(exp_dir, attn_map_name) print(f"{step} is on memory") keys = [key for key in attn_map_raws[0].keys()] print(len(keys)) key = keys[0] ######################## tgt_idx = 3 # indicating the location of generated images. attn_map_paired_rgb_grid_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png" attn_map_paired_rgb_grid_path = os.path.join(exp_dir, attn_map_paired_rgb_grid_name) print(attn_map_paired_rgb_grid_path) attn_map_paired_rgb_grid = Image.open(attn_map_paired_rgb_grid_path) attn_map_src_img = get_image(attn_map_paired_rgb_grid, row = 0, col = 0, image_size = 1024, grid_width = 10) attn_map_tgt_img = get_image(attn_map_paired_rgb_grid, row = 0, col = tgt_idx, image_size = 1024, grid_width = 10) h, w = 256, 256 num_of_grid = 64 plus_50 = 0 # key_idx_list = [0,2,4,6,8,10] key_idx_list = [6, 28] # (108 -> 0, 109 -> 1, ... , 140 -> 32) # if Swapping Attentio nin (108, 140) layer , use key_idx_list = [6, 28]. # 6==early upblock, 28==late upblock saved_attention_map_idx = [0] source_image = attn_map_src_img target_image = attn_map_tgt_img # resize source_image = source_image.resize((h, w)) target_image = target_image.resize((h, w)) # convert to numpy array source_image = np.array(source_image) target_image = np.array(target_image) canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True) canvas.put_image_data(source_image, w * 3, 0) canvas.put_image_data(target_image, 0, 0) canvas.put_image_data(source_image, w * 3, h) canvas.put_image_data(target_image, 0, h) # Display the canvas # display(canvas) def save_to_file(*args, **kwargs): canvas.to_file("my_file1.png") # Listen to changes on the ``image_data`` trait and call ``save_to_file`` when it changes. canvas.observe(save_to_file, "image_data") def on_click(x, y): cnt = 0 canvas.put_image_data(target_image, 0, 0) print(x, y) # draw a point canvas.fill_style = 'red' canvas.fill_circle(x, y, 4) for step_i, step in enumerate(range(len(saved_attention_map_idx))): attn_map_raw = attn_map_raws[step_i] for key_i, key_idx in enumerate(key_idx_list): key = keys[key_idx] num_of_grid = int(attn_map_raw[key].shape[-1] ** (0.5)) # normalize x,y grid_x_idx = int(x / (w / num_of_grid)) grid_y_idx = int(y / (h / num_of_grid)) print(grid_x_idx, grid_y_idx) grid_idx = grid_x_idx + grid_y_idx * num_of_grid attn_map = attn_map_raw[key][tgt_idx * 10:10 + tgt_idx * 10, grid_idx, :] attn_map = attn_map.sum(dim=0) attn_map = attn_map.reshape(num_of_grid, num_of_grid) # process attn_map to pil attn_map = attn_map.detach().cpu().numpy() # attn_map = attn_map / attn_map.max() # normalized_attn_map = attn_map normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8) normalized_attn_map = 1.0 - normalized_attn_map heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET) heatmap = cv2.resize(heatmap, (w, h)) attn_map = normalized_attn_map * 255 attn_map = attn_map.astype(np.uint8) attn_map = cv2.cvtColor(attn_map, cv2.COLOR_GRAY2RGB) # attn_map = cv2.cvtColor(attn_map, cv2.COLORMAP_JET) attn_map = cv2.resize(attn_map, (w, h)) # draw attn_map canvas.put_image_data(attn_map, w + step_i * 4 * w, h * key_i) # canvas.put_image_data(attn_map, w , h*key_i) # blend attn_map and target image alpha = 0.85 blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0) # draw blended image canvas.put_image_data(blended_image, w * 2 + step_i * 4 * w, h * key_i) cnt += 1 # Attach the event handler to the canvas canvas.on_mouse_down(on_click)