Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,532 Bytes
456aee9 483a977 456aee9 b154ac8 456aee9 faffa79 456aee9 faffa79 456aee9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import numpy as np
from PIL import Image
from dino_clip_featextract import img_transform_inv
from my_ipadapter_model import image_grid
from ncut_pytorch import NCUT, kway_ncut, rgb_from_tsne_3d, convert_to_lab_color
from ncut_pytorch.affinity_gamma import find_gamma_by_degree
from einops import rearrange
import torch
def ncut_tsne_multiple_images(image_embeds, n_eig=50, gamma=0.5, degree=0.5):
    """Run one joint NCUT over every image and colour the tokens with 3-D t-SNE.

    Args:
        image_embeds: 3-D tensor unpacked as ``(b, l, c)``. The first two axes
            are merged so that all images share a single NCUT problem.
        n_eig: Number of eigenvectors requested from ``NCUT``.
        gamma: Forwarded as ``affinity_focal_gamma``. When ``None`` it is
            chosen by ``find_gamma_by_degree`` instead.
        degree: Forwarded to ``find_gamma_by_degree``; only used when
            ``gamma`` is ``None``.

    Returns:
        Tuple ``(eigvec, rgb)``; both are rearranged from a flat
        ``(b * l, ...)`` layout back to ``(b, l, ...)``.
    """
    # Unpacking three values also rejects inputs that are not 3-D.
    n_images, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)

    focal_gamma = gamma
    if focal_gamma is None:
        focal_gamma = find_gamma_by_degree(tokens, degree, distance='rbf')

    # Both the eigen-solve and the t-SNE embedding are pinned to CUDA.
    solver = NCUT(n_eig, affinity_focal_gamma=focal_gamma, distance='rbf', device='cuda')
    eigvec, _ = solver.fit_transform(tokens)
    _, colors = rgb_from_tsne_3d(eigvec, device='cuda', perplexity=50)
    colors = convert_to_lab_color(colors)

    # Restore the per-image axis that was merged for the joint solve.
    per_image_colors = rearrange(colors, '(b l) c -> b l c', b=n_images)
    per_image_eigvec = rearrange(eigvec, '(b l) c -> b l c', b=n_images)
    return per_image_eigvec, per_image_colors
def _kway_cluster_one_image(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """Compute continuous k-way NCUT cluster scores for a single image.

    Args:
        image_embeds: 2-D tensor unpacked as ``(l, c)``.
        n_cluster: Number of clusters; only the first ``n_cluster``
            eigenvectors are handed to ``kway_ncut``.
        gamma: Forwarded as ``affinity_focal_gamma``. When ``None`` it is
            chosen by ``find_gamma_by_degree`` instead.
        degree: Forwarded to ``find_gamma_by_degree``; only used when
            ``gamma`` is ``None``.

    Returns:
        The output of ``kway_ncut(..., return_continuous=True)``.
    """
    # Unpacking two values also rejects inputs that are not 2-D.
    _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    n_tokens = tokens.shape[0]

    focal_gamma = gamma
    if focal_gamma is None:
        focal_gamma = find_gamma_by_degree(tokens, degree, distance='rbf')

    # Ask for more eigenvectors than clusters, capped by the token count.
    n_eig = min(n_cluster * 2 + 6, n_tokens // 2 - 1)
    # Sample size used by NCUT, capped at 1000 and at half of the tokens.
    num_samples = min(1000, n_tokens // 2)

    solver = NCUT(
        n_eig,
        num_sample=num_samples,
        affinity_focal_gamma=focal_gamma,
        distance='rbf',
        device='cuda',
    )
    eigvec, _ = solver.fit_transform(tokens)
    return kway_ncut(eigvec[:, :n_cluster], return_continuous=True)
def kway_cluster_per_image(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """Cluster each image independently and stack the per-image results.

    Args:
        image_embeds: Indexable along its first axis; every entry is passed
            to ``_kway_cluster_one_image`` on its own.
        n_cluster: Number of clusters per image.
        gamma: Forwarded unchanged to ``_kway_cluster_one_image``.
        degree: Forwarded unchanged to ``_kway_cluster_one_image``.

    Returns:
        ``torch.stack`` of the per-image results (new leading image axis).
    """
    per_image_scores = [
        _kway_cluster_one_image(image_embeds[idx], n_cluster, gamma, degree)
        for idx in range(image_embeds.shape[0])
    ]
    return torch.stack(per_image_scores)
def kway_cluster_multiple_images(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """Cluster the tokens of all images jointly with a single k-way NCUT.

    Args:
        image_embeds: 3-D tensor unpacked as ``(b, l, c)``. The first two axes
            are merged so that every image shares the same clusters.
        n_cluster: Number of clusters; only the first ``n_cluster``
            eigenvectors are handed to ``kway_ncut``.
        gamma: Forwarded as ``affinity_focal_gamma``. When ``None`` it is
            chosen by ``find_gamma_by_degree`` instead.
        degree: Forwarded to ``find_gamma_by_degree``; only used when
            ``gamma`` is ``None``.

    Returns:
        The continuous ``kway_ncut`` output rearranged to ``(b, l, ...)``.
    """
    # Unpacking three values also rejects inputs that are not 3-D.
    n_images, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    n_tokens = tokens.shape[0]

    focal_gamma = gamma
    if focal_gamma is None:
        focal_gamma = find_gamma_by_degree(tokens, degree, distance='rbf')

    # Ask for more eigenvectors than clusters, capped by the token count.
    n_eig = min(n_cluster * 2 + 6, n_tokens // 2 - 1)
    # Sample size used by NCUT, capped at 1000 and at half of the tokens.
    num_samples = min(1000, n_tokens // 2)

    solver = NCUT(
        n_eig,
        num_sample=num_samples,
        affinity_focal_gamma=focal_gamma,
        distance='rbf',
        device='cuda',
    )
    eigvec, _ = solver.fit_transform(tokens)
    scores = kway_ncut(eigvec[:, :n_cluster], return_continuous=True)
    # Restore the per-image axis that was merged for the joint solve.
    return rearrange(scores, '(b l) c -> b l c', b=n_images)
def get_single_multi_discrete_rgbs(joint_rgbs, single_eigvecs):
    """Flatten per-token colours to one mean colour per cluster, per image.

    Every token is assigned to the cluster with the largest score
    (``argmax`` over the last axis of ``single_eigvecs``), and its colour is
    replaced by the mean colour of all tokens of that cluster within the same
    image.

    Args:
        joint_rgbs: NumPy array indexed as ``[image, token, channel]``. It is
            scaled by 255 on output, so values are presumably in ``[0, 1]``
            (not checked here).
        single_eigvecs: Torch tensor indexed as ``[image, token, cluster]``;
            its last axis size is the number of clusters.

    Returns:
        ``np.uint8`` array with the same shape as ``joint_rgbs``.
    """
    n_cluster = single_eigvecs.shape[-1]
    discrete_rgbs = np.zeros_like(joint_rgbs)
    for i_img in range(joint_rgbs.shape[0]):
        rgb = joint_rgbs[i_img]
        # detach() keeps this usable on tensors that still track gradients.
        labels = single_eigvecs[i_img].detach().cpu().numpy().argmax(-1)
        # Basic indexing yields a view, so writes land in ``discrete_rgbs``.
        discrete = discrete_rgbs[i_img]
        for i_cluster in range(n_cluster):
            members = labels == i_cluster
            # A cluster with no tokens has nothing to paint; skipping it
            # avoids NumPy's "Mean of empty slice" warning and a NaN mean.
            if not members.any():
                continue
            discrete[members] = rgb[members].mean(0)
    return (discrete_rgbs * 255).astype(np.uint8)
def get_center_features(image_embeds, cluster_labels, n_cluster):
    """Average the embeddings of every cluster into one centre vector each.

    Args:
        image_embeds: Tensor whose first axis is indexed by the boolean
            cluster mask and whose last axis is the feature dimension.
        cluster_labels: Array or tensor of integer labels; compared with
            ``==`` against each cluster id to build the mask.
        n_cluster: Number of cluster ids (``0 .. n_cluster - 1``) to fill.

    Returns:
        Tensor of shape ``(n_cluster, feature_dim)`` allocated with
        ``torch.zeros`` defaults. Rows of empty clusters hold the constant
        114514 in every position.
    """
    feature_dim = image_embeds.shape[-1]
    centers = torch.zeros((n_cluster, feature_dim))
    for cluster_id in range(n_cluster):
        members = cluster_labels == cluster_id
        if members.sum() > 0:
            centers[cluster_id] = image_embeds[members].mean(0)
            continue
        # Empty cluster: fill with a large constant instead of zeros,
        # presumably so distance-based matching treats it as far from
        # every real centre.
        centers[cluster_id] = torch.ones_like(image_embeds[0]) * 114514
    return centers
def cosine_similarity(A, B):
    """Return the pairwise cosine similarity between rows of ``A`` and ``B``.

    Each input is divided by its L2 norm along the last axis, then the
    normalised ``A`` is matrix-multiplied with the transpose of the
    normalised ``B``. Zero-norm rows are not special-cased.
    """
    def _unit_rows(x):
        return x / x.norm(dim=-1, keepdim=True)

    return torch.matmul(_unit_rows(A), _unit_rows(B).T)
from scipy.optimize import linear_sum_assignment
def hungarian_match_centers(center_features1, center_features2):
    """Match two sets of centres one-to-one by minimum total distance.

    The cost matrix is ``torch.cdist(center_features1, center_features2)``
    and the assignment is solved with SciPy's ``linear_sum_assignment``.

    Returns:
        The column indices of the optimal assignment (NumPy array), i.e. the
        index into ``center_features2`` chosen for each assigned row.
    """
    cost = torch.cdist(center_features1, center_features2).detach().cpu().numpy()
    _, matched_cols = linear_sum_assignment(cost)
    return matched_cols
def argmin_matching(center_features1, center_features2):
    """Map every centre in the first set to its nearest centre in the second.

    Unlike an assignment solver this is a plain per-row ``argmin`` over
    ``torch.cdist``, so several rows may pick the same target.

    Returns:
        NumPy array with one index into ``center_features2`` per row of
        ``center_features1``.
    """
    pairwise = torch.cdist(center_features1, center_features2)
    return pairwise.detach().cpu().numpy().argmin(axis=-1)
def match_centers(image_embed1, image_embed2, eigvec1, eigvec2, match_method='hungarian'):
    """Match the clusters of one image to the clusters of another.

    Tokens are hard-assigned to clusters via ``argmax`` over the last axis of
    each eigenvector tensor, one centre feature per cluster is computed with
    ``get_center_features``, and the two centre sets are then matched.

    Args:
        image_embed1: Embeddings of the first image.
        image_embed2: Embeddings of the second image.
        eigvec1: Cluster scores for the first image; its last axis size is
            used as the number of clusters for both images.
        eigvec2: Cluster scores for the second image.
        match_method: ``'hungarian'`` for a one-to-one assignment, or
            ``'argmin'`` for a nearest-centre lookup per cluster.

    Returns:
        NumPy index array produced by the selected matcher: for each cluster
        of the first image, the matched cluster index of the second image.

    Raises:
        ValueError: If ``match_method`` is not a supported name.
    """
    # Validate before any tensor work: an unknown method previously fell
    # through both branches and surfaced as an UnboundLocalError at return.
    if match_method not in ('hungarian', 'argmin'):
        raise ValueError(
            f"Unknown match_method {match_method!r}; expected 'hungarian' or 'argmin'."
        )

    cluster_label1 = eigvec1.argmax(-1).cpu().numpy()
    cluster_label2 = eigvec2.argmax(-1).cpu().numpy()
    n_cluster = eigvec1.shape[-1]
    center_features1 = get_center_features(image_embed1, cluster_label1, n_cluster=n_cluster)
    center_features2 = get_center_features(image_embed2, cluster_label2, n_cluster=n_cluster)

    if match_method == 'hungarian':
        return hungarian_match_centers(center_features1, center_features2)
    return argmin_matching(center_features1, center_features2)
def match_centers_three_images(image_embeds, eigvecs, match_method='hungarian'):
    """Chain cluster matches across three images: 0 -> 1 and 1 -> 2.

    Args:
        image_embeds: Indexed at positions 0, 1 and 2. Per the original
            author's note the order is A2, A1, B1.
        eigvecs: Per-image cluster scores, indexed the same way. They are
            passed to ``match_centers``, which takes ``argmax`` over their
            last axis, so each entry presumably has a trailing cluster axis.
        match_method: Forwarded to ``match_centers``.

    Returns:
        Tuple ``(A2_to_A1, A1_to_B1)`` of the two ``match_centers`` results.
    """
    def _match_pair(src, dst):
        return match_centers(
            image_embeds[src],
            image_embeds[dst],
            eigvecs[src],
            eigvecs[dst],
            match_method=match_method,
        )

    A2_to_A1 = _match_pair(0, 1)
    A1_to_B1 = _match_pair(1, 2)
    return A2_to_A1, A1_to_B1
def match_centers_two_images(image_embed1, image_embed2, eigvec1, eigvec2, match_method='hungarian'):
    """Match the clusters of two images; thin alias for ``match_centers``."""
    return match_centers(
        image_embed1,
        image_embed2,
        eigvec1,
        eigvec2,
        match_method=match_method,
    )
def plot_clusters(image, eigvec, cluster_order, hw=16):
    """Render one 128x128 tile per cluster, highlighting that cluster's tokens.

    Args:
        image: Input accepted by ``img_transform_inv``; its result is resized
            to 128x128 with nearest-neighbour resampling.
        eigvec: Cluster scores for the image. Tokens are assigned by
            ``argmax`` over the last axis; the first token is dropped before
            reshaping to ``(hw, hw)`` (presumably a CLS/global token -- TODO
            confirm).
        cluster_order: Iterable of cluster ids; one tile is produced per id,
            in this order.
        hw: Side length of the square token grid.

    Returns:
        List of ``PIL.Image`` tiles in which pixels outside the cluster are
        dimmed to 10% brightness.
    """
    img = img_transform_inv(image).resize((128, 128), resample=Image.Resampling.NEAREST)
    # Loop-invariant work, computed once instead of once per cluster.
    base = np.array(img).astype(np.float32) / 255
    labels = eigvec.argmax(-1)

    cluster_images = []
    for idx_cluster in cluster_order:
        mask = (labels == idx_cluster).cpu().numpy()[1:].reshape(hw, hw)
        # Round-trip through PIL to upsample the token mask to tile size.
        mask = (mask * 255).astype(np.uint8)
        mask = Image.fromarray(mask).resize((128, 128), resample=Image.Resampling.NEAREST)
        mask = np.array(mask).astype(np.float32) / 255
        mask = np.stack([mask] * 3, axis=-1)
        # Dim, rather than black out, everything outside the cluster.
        mask[mask == 0] = 0.1
        tile = (base * mask * 255).astype(np.uint8)
        cluster_images.append(Image.fromarray(tile))
    return cluster_images
def grid_one_image(image, eigvec, cluster_order, discrete_rgb, hw=16, n_cols=10):
    """Lay out the cluster tiles of one image as rows of ``2 + n_cols`` tiles.

    Every row starts with the resized input image and the discrete NCUT
    colour map, followed by ``n_cols`` cluster tiles from ``plot_clusters``.
    The last row is padded with black tiles when needed.

    Args:
        image: Input accepted by ``img_transform_inv``.
        eigvec: Cluster scores forwarded to ``plot_clusters``.
        cluster_order: Cluster ids forwarded to ``plot_clusters``.
        discrete_rgb: Per-token colours; the first entry is dropped and the
            rest reshaped to ``(hw, hw, 3)``.
        hw: Side length of the square token grid.
        n_cols: Number of cluster tiles per row.

    Returns:
        List of rows, each a list of ``PIL.Image`` tiles.
    """
    tile_size = (128, 128)
    nearest = Image.Resampling.NEAREST

    tiles = plot_clusters(image, eigvec, cluster_order, hw)
    source_tile = img_transform_inv(image).resize(tile_size, resample=nearest)
    ncut_tile = Image.fromarray(discrete_rgb[1:].reshape(hw, hw, 3)).resize(
        tile_size, resample=nearest
    )

    # Pad with black tiles up to the next multiple of n_cols.
    n_padding = (-len(tiles)) % n_cols
    blank_tile = Image.fromarray(np.zeros((128, 128, 3), dtype=np.uint8))
    tiles.extend([blank_tile] * n_padding)

    # Each row is prefixed with the source image and its NCUT colour map.
    row_prefix = [source_tile, ncut_tile]
    return [
        row_prefix + tiles[start:start + n_cols]
        for start in range(0, len(tiles), n_cols)
    ]
def grid_multiple_images(images, eigvecs, cluster_orders, discrete_rgbs, hw=16, n_cols=10):
    """Build tile rows for several images and interleave them row by row.

    The output order is: row 0 of every image, then row 1 of every image,
    and so on. The number of rows is taken from the first image.

    Returns:
        List of rows, each a list of tiles as produced by ``grid_one_image``.
    """
    rows_per_image = [
        grid_one_image(image, eigvec, cluster_order, discrete_rgb, hw, n_cols)
        for image, eigvec, cluster_order, discrete_rgb
        in zip(images, eigvecs, cluster_orders, discrete_rgbs)
    ]
    n_rows = len(rows_per_image[0])
    return [
        image_rows[i_row]
        for i_row in range(n_rows)
        for image_rows in rows_per_image
    ]
def get_correspondence_plot(images, eigvecs, cluster_orders, discrete_rgbs, hw=16, n_cols=10):
    """Assemble the cluster-correspondence figure for a set of images.

    The number of cluster columns is capped at the number of clusters (last
    axis of ``eigvecs``). The interleaved rows from ``grid_multiple_images``
    are then flattened and passed to ``image_grid``.

    Returns:
        Whatever ``image_grid`` returns for the flattened tiles.
    """
    cluster_cols = min(n_cols, eigvecs.shape[-1])
    rows = grid_multiple_images(images, eigvecs, cluster_orders, discrete_rgbs, hw, cluster_cols)
    n_grid_rows = len(rows)
    # Actual row width includes the two leading tiles added per row.
    n_grid_cols = len(rows[0])
    flat_tiles = [tile for row in rows for tile in row]
    return image_grid(flat_tiles, n_grid_rows, n_grid_cols)