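# NOTE(review): the first three lines of this file as captured ("Spaces:",
# "Runtime error", "Runtime error") appear to be web-page residue rather than
# part of the module; they are preserved here as a comment so the file parses.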
# NOTE(review): `imp` is deprecated and was removed in Python 3.12; it appears
# unused in the visible code -- confirm and drop it if nothing else needs it.
import imp
import pickle

import torch
import torch.nn.functional as F

from .util import *
from .spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
# Truncation value; not referenced anywhere in the visible code.
truncation = 0.5
# Index into the generator's per-layer outputs selecting the layer whose
# activations are clustered in `compute_M`.
stop_idx = 11
# Presumably the number of clusters of the pickled clusterer (the cluster ids
# listed in `labels_map` span 0..17) -- TODO confirm.
n_clusters = 18

# NOTE(security): unpickling executes arbitrary code embedded in the file, so
# this catalog must only ever come from a trusted source.
# The context manager closes the file handle, which the bare
# `pickle.load(open(...))` form left open.
with open('./pretrained_models/ris/catalog.pkl', 'rb') as _catalog_file:
    clusterer = pickle.load(_catalog_file)
# Semantic part name -> class index.
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}

# Class index -> cluster ids that are merged into that class. In `compute_M`
# these ids are compared against the output of `clusterer.predict`.
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}

# Inverse of `labels2idx`: class index -> part name.
# NOTE: the name is misspelled and reads backwards relative to its content;
# it is kept unchanged because other modules may import it.
lables2idx = {idx: name for name, idx in labels2idx.items()}
n_class = len(lables2idx)

# Cluster id -> class index, i.e. `labels_map` flattened and inverted.
# Iterating `labels_map` in insertion order keeps the same key order as
# updating the dict class by class.
segid_map = {
    cluster_id: class_idx
    for class_idx, cluster_ids in labels_map.items()
    for cluster_id in cluster_ids.tolist()
}

# Module-level side effect: seeds torch's global RNG at import time.
torch.manual_seed(0)
# compute M given a style code.
def compute_M(w, generator, weights_deltas=None, device='cuda'):
    """Compute a binary channel-to-class assignment matrix M for a style code.

    The generator is run once. Activations of layer ``stop_idx`` are clustered
    with the module-level ``clusterer`` and the cluster ids are merged into
    ``n_class`` boolean masks via ``labels_map``. For every layer, the
    instance-normalised squared activations are averaged under each
    (bilinearly resized) mask, the per-layer results are concatenated along
    the channel axis, and each channel is assigned to a class by a softmax
    over the class axis (temperature 0.1) followed by a 0.8 threshold.
    Entries belonging to the layers listed in ``rgb_layer_idx`` are zeroed.

    Args:
        w: Style code forwarded unchanged to ``generator``.
        generator: Callable returning ``(_, outputs)``; ``outputs[i][0]`` and
            ``outputs[i][1]`` are indexed as 4-D activation tensors.
        weights_deltas: Forwarded to ``generator`` as ``weights_deltas``.
        device: Device the masks and layer activations are moved to.

    Returns:
        Float tensor of zeros and ones shaped ``[b, k, c]`` (batch, class,
        channels concatenated over all layers).
    """
    # get segmentation
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    # Distinct local names: reusing `w` here would overwrite the style-code
    # parameter.
    batch, _, height, width = cluster_layer.size()

    # create masks for each feature
    seg_mask = torch.from_numpy(seg_mask).view(batch, 1, height, width, 1).to(device)
    # combine masks for all indices of a particular segmentation class;
    # each per-class mask is [b,1,h,w], stacked to [b,k,1,h,w]
    all_seg_mask = torch.stack(
        [
            (seg_mask == labels_map[key].view(1, 1, 1, 1, -1).to(device)).any(-1)
            for key in range(n_class)
        ],
        1,
    )
    # Loop-invariant: the flattened float masks are identical for every layer.
    flat_seg_mask = all_seg_mask.flatten(0, 1).float()  # [b*k,1,h,w]

    # go through each activation layer and compute M
    M = []
    for layer_output in outputs:
        layer = layer_output[1].to(device)
        batch, _, height, width = layer.size()
        layer = F.instance_norm(layer).pow(2)

        # resize the segmentation masks to current activations' resolution
        layer_seg_mask = F.interpolate(
            flat_seg_mask,
            size=(height, width),
            mode='bilinear',
            align_corners=False,
        ).view(batch, -1, 1, height, width)

        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        # Already on `device`, so no further transfer is needed.
        M.append(masked_layer.sum([3, 4]) / (height * width))  # [b,k,c]

    M = torch.cat(M, -1)  # [b, k, c]

    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()

    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j, layer_part in enumerate(part_M):
            if j in rgb_layer_idx:
                layer_part.zero_()
        M[:, i] = list2style(part_M)

    return M
def blend_latents(source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
    """Blend the 'hair' channels of a reference latent into a source latent.

    Only the first entry of ``source_latent`` and ``ref_latent`` is used.
    Each is converted with ``generator.get_latent``, a channel mask is
    computed for both with ``compute_M``, the element-wise maximum of the two
    masks is passed through ``add_pose``, and the 'hair' row of the result
    drives ``add_direction`` to produce the blended style.

    Args:
        source_latent: Indexable latent batch; entry 0 is the source.
        ref_latent: Indexable latent batch; entry 0 is the reference.
        generator: Generator exposing ``get_latent`` and being callable with a
            style list and ``weights_deltas``.
        src_deltas: Weight deltas used when analysing the source and when
            rendering the blended result.
        ref_deltas: Weight deltas used when analysing the reference.
        device: Device the selected 'hair' mask is moved to before blending.

    Returns:
        Tuple ``(blend_out, blend)``: the first value returned by the
        generator for the blended style, and the blended style as produced by
        ``style2list``.
    """
    source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
    ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

    # The masks are always computed on the CPU, independent of `device`.
    source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
    ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')

    max_M = torch.max(source_M.expand_as(ref_M), ref_M)
    max_M = add_pose(max_M, labels2idx)

    part_M = max_M[:, labels2idx['hair']].to(device)
    # 1.3 is forwarded as the last argument of `add_direction`; its exact
    # meaning is defined by that helper, which is not visible here.
    blend = style2list(add_direction(source, ref, part_M, 1.3))
    # The blended style is rendered with the source's weight deltas.
    blend_out, _ = generator(blend, weights_deltas=src_deltas)
    return blend_out, blend