Spaces:
Runtime error
Runtime error
import torch | |
import os | |
from PIL import Image | |
import numpy as np | |
from ipycanvas import Canvas | |
import cv2 | |
from visualize_attention_src.utils import get_image | |
exp_dir = "saved_attention_map_results" | |
style_name = "line_art" | |
src_name = "cat" | |
tgt_name = "dog" | |
steps = ["20"] | |
seed = "4" | |
saved_dtype = "tensor" | |
attn_map_raws = [] | |
for step in steps: | |
attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}" # new | |
if saved_dtype == 'uint8': | |
attn_map_name = attn_map_name_wo_ext + '_uint8.npy' | |
attn_map_path = os.path.join(exp_dir, attn_map_name) | |
attn_map_raws.append(np.load(attn_map_path, allow_pickle=True)) | |
else: | |
attn_map_name = attn_map_name_wo_ext + '.pt' | |
attn_map_path = os.path.join(exp_dir, attn_map_name) | |
attn_map_raws.append(torch.load(attn_map_path)) | |
print(attn_map_path) | |
attn_map_path = os.path.join(exp_dir, attn_map_name) | |
print(f"{step} is on memory") | |
keys = [key for key in attn_map_raws[0].keys()] | |
print(len(keys)) | |
key = keys[0] | |
######################## | |
tgt_idx = 3 # indicating the location of generated images. | |
attn_map_paired_rgb_grid_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png" | |
attn_map_paired_rgb_grid_path = os.path.join(exp_dir, attn_map_paired_rgb_grid_name) | |
print(attn_map_paired_rgb_grid_path) | |
attn_map_paired_rgb_grid = Image.open(attn_map_paired_rgb_grid_path) | |
attn_map_src_img = get_image(attn_map_paired_rgb_grid, row = 0, col = 0, image_size = 1024, grid_width = 10) | |
attn_map_tgt_img = get_image(attn_map_paired_rgb_grid, row = 0, col = tgt_idx, image_size = 1024, grid_width = 10) | |
h, w = 256, 256 | |
num_of_grid = 64 | |
plus_50 = 0 | |
# key_idx_list = [0,2,4,6,8,10] | |
key_idx_list = [6, 28] | |
# (108 -> 0, 109 -> 1, ... , 140 -> 32) | |
# if Swapping Attentio nin (108, 140) layer , use key_idx_list = [6, 28]. | |
# 6==early upblock, 28==late upblock | |
saved_attention_map_idx = [0] | |
source_image = attn_map_src_img | |
target_image = attn_map_tgt_img | |
# resize | |
source_image = source_image.resize((h, w)) | |
target_image = target_image.resize((h, w)) | |
# convert to numpy array | |
source_image = np.array(source_image) | |
target_image = np.array(target_image) | |
canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True) | |
canvas.put_image_data(source_image, w * 3, 0) | |
canvas.put_image_data(target_image, 0, 0) | |
canvas.put_image_data(source_image, w * 3, h) | |
canvas.put_image_data(target_image, 0, h) | |
# Display the canvas | |
# display(canvas) | |
def save_to_file(*args, **kwargs): | |
canvas.to_file("my_file1.png") | |
# Listen to changes on the ``image_data`` trait and call ``save_to_file`` when it changes. | |
canvas.observe(save_to_file, "image_data") | |
def on_click(x, y): | |
cnt = 0 | |
canvas.put_image_data(target_image, 0, 0) | |
print(x, y) | |
# draw a point | |
canvas.fill_style = 'red' | |
canvas.fill_circle(x, y, 4) | |
for step_i, step in enumerate(range(len(saved_attention_map_idx))): | |
attn_map_raw = attn_map_raws[step_i] | |
for key_i, key_idx in enumerate(key_idx_list): | |
key = keys[key_idx] | |
num_of_grid = int(attn_map_raw[key].shape[-1] ** (0.5)) | |
# normalize x,y | |
grid_x_idx = int(x / (w / num_of_grid)) | |
grid_y_idx = int(y / (h / num_of_grid)) | |
print(grid_x_idx, grid_y_idx) | |
grid_idx = grid_x_idx + grid_y_idx * num_of_grid | |
attn_map = attn_map_raw[key][tgt_idx * 10:10 + tgt_idx * 10, grid_idx, :] | |
attn_map = attn_map.sum(dim=0) | |
attn_map = attn_map.reshape(num_of_grid, num_of_grid) | |
# process attn_map to pil | |
attn_map = attn_map.detach().cpu().numpy() | |
# attn_map = attn_map / attn_map.max() | |
# normalized_attn_map = attn_map | |
normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8) | |
normalized_attn_map = 1.0 - normalized_attn_map | |
heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET) | |
heatmap = cv2.resize(heatmap, (w, h)) | |
attn_map = normalized_attn_map * 255 | |
attn_map = attn_map.astype(np.uint8) | |
attn_map = cv2.cvtColor(attn_map, cv2.COLOR_GRAY2RGB) | |
# attn_map = cv2.cvtColor(attn_map, cv2.COLORMAP_JET) | |
attn_map = cv2.resize(attn_map, (w, h)) | |
# draw attn_map | |
canvas.put_image_data(attn_map, w + step_i * 4 * w, h * key_i) | |
# canvas.put_image_data(attn_map, w , h*key_i) | |
# blend attn_map and target image | |
alpha = 0.85 | |
blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0) | |
# draw blended image | |
canvas.put_image_data(blended_image, w * 2 + step_i * 4 * w, h * key_i) | |
cnt += 1 | |
# Attach the event handler to the canvas | |
canvas.on_mouse_down(on_click) |