ethanNeuralImage's picture
start implementing PTI, RIS implemented
ab189a8
raw
history blame contribute delete
No virus
3.96 kB
import imp
import torch
import pickle
from .util import *
from .spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
# Truncation value exported at module level (the functions in this file pass
# truncation=1 explicitly to generator.get_latent).
truncation = 0.5
# Index into the generator's `outputs` whose first element is clustered to
# obtain the segmentation used by compute_M.
stop_idx = 11
# Number of clusters in the pretrained catalog; labels_map groups the cluster
# ids 0..17 into semantic parts.
n_clusters = 18
# Pretrained clusterer exposing `.predict` (used by compute_M). The path is
# relative to the current working directory. The file handle is closed
# promptly via `with` instead of being left to the garbage collector.
# NOTE(review): pickle can execute arbitrary code on load - only ever point
# this at a trusted checkpoint.
with open('./pretrained_models/ris/catalog.pkl', 'rb') as _catalog_file:
    clusterer = pickle.load(_catalog_file)
# Semantic part name -> part index.
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}
# Part index -> 1-D tensor of clusterer cluster ids that make up that part.
# Together the values cover each of the n_clusters ids exactly once.
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}
# Inverse of labels2idx: part index -> part name.
idx2labels = {idx: name for name, idx in labels2idx.items()}
# Backward-compatible alias for the original (misspelled) name.
lables2idx = idx2labels
# Number of semantic parts.
n_class = len(idx2labels)
# Cluster id -> part index. Derived from labels_map (in its insertion order)
# so the two tables cannot drift apart.
segid_map = {
    seg_id: part
    for part, seg_ids in labels_map.items()
    for seg_id in seg_ids.tolist()
}
# Seed torch's global RNG at import time.
torch.manual_seed(0)
# compute M given a style code.
@torch.no_grad()
def compute_M(w, generator, weights_deltas=None, device='cuda'):
    """Compute the binary channel-to-part assignment matrix ``M`` for ``w``.

    The generator is run on the style code ``w``. The activation
    ``outputs[stop_idx][0]`` is clustered with the module-level
    ``clusterer``, and cluster ids are grouped into ``n_class`` semantic
    parts via ``labels_map``. For every entry of ``outputs``, the
    instance-normalised squared activation ``outputs[i][1]`` is summed inside
    each part's (bilinearly resized) mask and divided by the spatial size.
    A softmax over parts with temperature 0.1, followed by a ``> 0.8``
    threshold, turns these scores into a 0/1 assignment. Columns belonging
    to layers listed in ``rgb_layer_idx`` are zeroed (torgb transfers, see
    https://arxiv.org/abs/2011.12799).

    Args:
        w: style code, passed unchanged to ``generator``.
        generator: callable returning ``(_, outputs)``; each ``outputs[i]``
            is indexable and its ``[0]`` / ``[1]`` items are 4-D tensors.
        weights_deltas: optional, forwarded to ``generator``.
        device: device on which masks and ``M`` are computed.

    Returns:
        Float tensor of 0/1 values with shape ``[b, n_class, C_total]``,
        where ``C_total`` is the channel count summed over all layers.
    """
    # --- Segmentation from clustering one intermediate activation. ---
    # _, outputs = generator(w, is_cluster=1)
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    seg_ids = clusterer.predict(flatten_act(cluster_layer))
    # Distinct names so the style-code argument `w` is not shadowed by a
    # spatial width.
    seg_b, _, seg_h, seg_w = cluster_layer.size()
    # NOTE(review): assumes predict returns a numpy array of seg_b*seg_h*seg_w
    # cluster ids - confirm against the clusterer implementation.
    seg_ids = torch.from_numpy(seg_ids).view(seg_b, 1, seg_h, seg_w, 1).to(device)

    # --- One boolean mask per semantic part. ---
    # A pixel belongs to part `key` if its cluster id is any of labels_map[key].
    part_masks = []
    for key in range(n_class):
        part_ids = labels_map[key].view(1, 1, 1, 1, -1).to(device)
        part_masks.append((seg_ids == part_ids).any(-1))  # [b, 1, h, w]
    all_seg_mask = torch.stack(part_masks, 1)  # [b, n_class, 1, h, w]
    # Loop-invariant: masks flattened to a 4-D batch for F.interpolate.
    flat_seg_mask = all_seg_mask.flatten(0, 1).float()  # [b*n_class, 1, h, w]

    # --- Per-layer, per-part channel energy. ---
    M = []
    for entry in outputs:
        layer = entry[1].to(device)
        b, _, height, width = layer.size()
        energy = F.instance_norm(layer).pow(2)
        # Resize the part masks to this layer's resolution.
        layer_seg_mask = F.interpolate(flat_seg_mask, align_corners=False,
                                       size=(height, width),
                                       mode='bilinear').view(b, -1, 1, height, width)
        masked_layer = energy.unsqueeze(1) * layer_seg_mask  # [b, k, c, h, w]
        M.append(masked_layer.sum([3, 4]) / (height * width))  # [b, k, c]
    M = torch.cat(M, -1)  # [b, k, C_total]

    # Softmax over parts assigns each channel to a segmentation class;
    # simple thresholding then makes the assignment binary.
    M = F.softmax(M / .1, 1)
    M = (M > .8).float()

    # Zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j, chunk in enumerate(part_M):
            if j in rgb_layer_idx:
                chunk.zero_()
        M[:, i] = list2style(part_M)
    return M
def blend_latents(source_latent, ref_latent, generator, src_deltas=None,
                  ref_deltas=None, device='cuda', part='hair', strength=1.3):
    """Transfer one semantic part from a reference latent onto a source latent.

    Both latents are mapped with ``generator.get_latent(..., truncation=1,
    is_latent=True)``; only ``ref_latent[0]`` is used as the reference.
    Channel assignment matrices are computed for source and reference with
    ``compute_M`` (on the CPU), merged with an element-wise max, passed
    through ``add_pose``, and the row for ``part`` drives ``add_direction``
    to move the source style toward the reference. The blended style is
    rendered with the source's weight deltas.

    Args:
        source_latent: latent of the source image.
        ref_latent: reference latent(s); only index 0 is used.
        generator: generator exposing ``get_latent`` and ``__call__``.
        src_deltas: optional weight deltas for the source; also used to
            render the blended result.
        ref_deltas: optional weight deltas for the reference.
        device: device to which the selected part mask is moved.
        part: key of ``labels2idx`` to transfer. Defaults to ``'hair'``
            (the previously hard-coded value).
        strength: coefficient passed to ``add_direction``. Defaults to 1.3
            (the previously hard-coded value).

    Returns:
        ``(blend_out, blend)``: the first generator output for the blended
        style, and the blended style as returned by ``style2list``.

    Raises:
        KeyError: if ``part`` is not a key of ``labels2idx``.
    """
    source = generator.get_latent(source_latent, truncation=1, is_latent=True)
    ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
    # M is computed on the CPU regardless of `device`.
    source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
    ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')
    # compute_M returns 0/1 entries, so the max is a logical OR: a channel
    # counts for a part if either image assigns it there.
    max_M = torch.max(source_M.expand_as(ref_M), ref_M)
    max_M = add_pose(max_M, labels2idx)
    part_M = max_M[:, labels2idx[part]].to(device)
    blend = style2list(add_direction(source, ref, part_M, strength))
    # Render the blended style with the source's weight deltas.
    blend_out, _ = generator(blend, weights_deltas=src_deltas)
    return blend_out, blend