import pickle

import torch
import torch.nn.functional as F

from .util import *  # flatten_act, style2list, list2style, add_pose, add_direction, rgb_layer_idx, ...
from .spherical_kmeans import MiniBatchSphericalKMeans as sKmeans  # kept so pickle can resolve the clusterer's class




truncation = 0.5  # default truncation psi (not used directly in this file)
stop_idx = 11     # index of the generator layer whose activations are clustered
n_clusters = 18   # number of k-means clusters in the pretrained catalog

# pretrained spherical k-means model that segments generator activations
with open('./pretrained_models/ris/catalog.pkl', 'rb') as f:
	clusterer = pickle.load(f)

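# semantic part name -> class index (the rows of M)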
labels2idx = {
	'nose': 0,
	'eyes': 1,
	'mouth': 2,
	'hair': 3,
	'background': 4,
	'cheek': 5,
	'neck': 6,
	'clothes': 7,
}

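# class index -> ids of the k-means clusters that make up that part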
labels_map = {
	0: torch.tensor([7]),
	1: torch.tensor([1,6]),
	2: torch.tensor([4]),
	3: torch.tensor([0,3,5,8,10,15,16]),
	4: torch.tensor([11,13,14]),
	5: torch.tensor([9]),
	6: torch.tensor([17]),
	7: torch.tensor([2,12]),
}

idx2labels = {v: k for k, v in labels2idx.items()}  # class index -> part name
n_class = len(idx2labels)

# cluster id -> class index (inverse of labels_map);
# e.g. segid_map[11] == 4, i.e. k-means cluster 11 belongs to 'background'
segid_map = {seg_id: cls for cls, ids in labels_map.items() for seg_id in ids.tolist()}

torch.manual_seed(0)  # fixed seed for reproducibility


# Compute the channel-to-part assignment matrix M for a given style code.
# Returns a binary [b, n_class, total style channels] matrix.
@torch.no_grad()
def compute_M(w, generator, weights_deltas=None, device='cuda'):
	M = []

	# get a coarse segmentation by clustering the stop_idx layer's activations
	_, outputs = generator(w, weights_deltas=weights_deltas)
	cluster_layer = outputs[stop_idx][0]
	activation = flatten_act(cluster_layer)
	seg_mask = clusterer.predict(activation)
	b, c, h, w = cluster_layer.size()  # note: w now holds the feature width, shadowing the style code

	# create a binary mask for each semantic class
	all_seg_mask = []
	seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)
	
	for key in range(n_class):
		# combine masks for all indices for a particular segmentation class
		indices = labels_map[key].view(1,1,1,1,-1) 
		key_mask = (seg_mask == indices.to(device)).any(-1) #[b,1,h,w]
		all_seg_mask.append(key_mask)
		
	all_seg_mask = torch.stack(all_seg_mask, 1)
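	# all_seg_mask: [b, n_class, 1, h, w] boolean mask per semantic class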

	# go through each activation layer and compute M
	for layer_idx in range(len(outputs)):
		layer = outputs[layer_idx][1].to(device)
		b, c, h, w = layer.size()
		layer = F.instance_norm(layer)
		layer = layer.pow(2)  # squared, normalized activations ("energy")

		# resize the segmentation masks to the current activation resolution
		layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), size=(h, w),
									   mode='bilinear', align_corners=False).view(b, -1, 1, h, w)
		
		masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b, k, c, h, w]
		masked_layer = masked_layer.sum([3, 4]) / (h * w)  # spatial mean of the masked energy: [b, k, c]

		M.append(masked_layer.to(device))

	M = torch.cat(M, -1)  # [b, k, total channels across all layers]

	# softmax to assign each channel to a particular segmentation class
	M = F.softmax(M / 0.1, 1)
	# simple thresholding to a binary assignment
	M = (M > 0.8).float()
	
	# zero out assignments for the tRGB layers (per StyleSpace, https://arxiv.org/abs/2011.12799)
	for i in range(n_class):
		part_M = style2list(M[:, i])
		for j in range(len(part_M)):
			if j in rgb_layer_idx:
				part_M[j].zero_()
		part_M = list2style(part_M)
		M[:, i] = part_M

	return M
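
# Minimal usage sketch (hypothetical `g` and `z`; a sketch, not part of the
# original pipeline). Assumes `g` is the StyleGAN2 generator this repo builds
# and `w` is a style code from g.get_latent; M then assigns every style
# channel to one of the n_class semantic parts:
#
#   z = torch.randn(1, 512, device='cuda')
#   w = g.get_latent(z, truncation=truncation)
#   M = compute_M(w, g, device='cuda')  # [1, n_class, total style channels]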

def blend_latents(source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
	source = generator.get_latent(source_latent, truncation=1, is_latent=True)
	ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
	source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
	ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')

	blend_deltas = src_deltas

	# union of the binary channel assignments from source and reference
	max_M = torch.max(source_M.expand_as(ref_M), ref_M)
	max_M = add_pose(max_M, labels2idx)
	idx = labels2idx['hair']

	part_M = max_M[:, idx].to(device)

	# move the source code toward the reference along the hair channels only
	blend = style2list(add_direction(source, ref, part_M, 1.3))
	blend_out, _ = generator(blend, weights_deltas=blend_deltas)

	return blend_out, blend
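
# Example driver (hypothetical names; a sketch under the assumption that
# `generator` is built the same way as elsewhere in this repo and that
# `src_w`/`ref_w` are latent codes from an encoder or generator.get_latent):
#
#   blended_img, blended_style = blend_latents(src_w, ref_w, generator)
#   # blended_img: [1, 3, H, W] image tensor; blended_style: per-layer style list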