import pickle

import torch
import torch.nn.functional as F

from .util import *
from .spherical_kmeans import MiniBatchSphericalKMeans as sKmeans

truncation = 0.5
stop_idx = 11        # generator layer whose activations are clustered for segmentation
n_clusters = 18      # number of k-means clusters in the pretrained catalog

with open('./pretrained_models/ris/catalog.pkl', 'rb') as f:
    clusterer = pickle.load(f)

# semantic class name -> class index
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}

# class index -> k-means cluster ids belonging to that class
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}

# class index -> class name (inverse of labels2idx)
idx2labels = dict((v, k) for k, v in labels2idx.items())
n_class = len(idx2labels)

# k-means cluster id -> semantic class index
segid_map = {}
for seg_class, cluster_ids in labels_map.items():
    segid_map.update(dict.fromkeys(cluster_ids.tolist(), seg_class))

torch.manual_seed(0)


# Compute the channel-to-class membership matrix M for a given style code.
@torch.no_grad()
def compute_M(w, generator, weights_deltas=None, device='cuda'):
    M = []

    # get segmentation
    # _, outputs = generator(w, is_cluster=1)
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    b, c, h, w = cluster_layer.size()  # note: shadows the style-code argument w, which is not used below

    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)

    for key in range(n_class):
        # combine masks for all cluster ids belonging to this segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1)
        key_mask = (seg_mask == indices.to(device)).any(-1)  # [b,1,h,w]
        all_seg_mask.append(key_mask)

    all_seg_mask = torch.stack(all_seg_mask, 1)

    # go through each activation layer and compute M
    for layer_idx in range(len(outputs)):
        layer = outputs[layer_idx][1].to(device)
        b, c, h, w = layer.size()
        layer = F.instance_norm(layer)
        layer = layer.pow(2)

        # resize the segmentation masks to the current activations' resolution
        layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
                                       size=(h, w), mode='bilinear').view(b, -1, 1, h, w)

        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (h * w)   # [b,k,c]

        M.append(masked_layer.to(device))

    M = torch.cat(M, -1)  # [b, k, c]

    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / 0.1, 1)
    # simple thresholding
    M = (M > 0.8).float()

    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j in range(len(part_M)):
            if j in rgb_layer_idx:
                part_M[j].zero_()
        part_M = list2style(part_M)
        M[:, i] = part_M

    return M


# Blend the reference hair into the source: compute M for both codes, take the
# element-wise max, select the hair class, and move the source style toward the
# reference along the masked direction.
def blend_latents(source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
    # print(source_latent.shape)
    source = generator.get_latent(source_latent, truncation=1, is_latent=True)
    # print(ref_latent.shape)
    ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

    source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
    ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')

    blend_deltas = src_deltas

    max_M = torch.max(source_M.expand_as(ref_M), ref_M)
    max_M = add_pose(max_M, labels2idx)

    idx = labels2idx['hair']
    part_M = max_M[:, idx].to(device)
    part_M_mask = style2list(part_M)
    blend = style2list(add_direction(source, ref, part_M, 1.3))
    blend_out, _ = generator(blend, weights_deltas=blend_deltas)
    # print(blend_out.shape)
    return blend_out, blend
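
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It assumes
# a StyleGAN2 generator with the interface used above (exposing `get_latent`
# and returning per-layer activations when called) plus source/reference W+
# latent codes. `load_generator`, `source_latent`, and `ref_latent` are
# placeholder names, not symbols defined in this codebase.
#
#     generator = load_generator('./pretrained_models/stylegan2-ffhq.pt')
#     blend_out, blend = blend_latents(source_latent, ref_latent, generator,
#                                      device='cuda')
#     # blend_out: generated image with the reference hair blended in
#     # blend: the edited per-layer style list that produced it
# ---------------------------------------------------------------------------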