import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import pickle
import warnings
warnings.filterwarnings("ignore")
from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
import glob
from pathlib import Path
import lpips  # note: this name is shadowed by the FaceMetric instance created below
import argparse
import gc
import cv2
import dlib  # needed for the face-landmark predictor loaded below
from PIL import Image  # needed for Image.fromarray in the transfer functions
from models.pti.e4e_projection import projection
from model import *
from util import *
from e4e.models.psp import pSp
import torchvision.transforms as transforms
from torch.nn import DataParallel
import torchvision.transforms.functional as TF
from FaceQualityMetrics.utils import FaceMetric
from hyperstyle.utils.model_utils import load_model
from configs.paths_config import model_paths
from models.pti.manipulator import Manipulator
from models.pti.wrapper import Generator_wrapper
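# Iterative GAN inversion in the HyperStyle style: each pass refines the
# reconstruction y_hat, the latent code, and (when not already supplied)
# the per-layer weight deltas predicted by `net`.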
def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False, resize_outputs=False, weights_deltas=None):
    y_hat, latent, codes = None, None, None
    if return_intermediate_results:
        results_batch = {idx: [] for idx in range(inputs.shape[0])}
        results_latent = {idx: [] for idx in range(inputs.shape[0])}
        results_deltas = {idx: [] for idx in range(inputs.shape[0])}
    else:
        results_batch, results_latent, results_deltas = None, None, None
    if weights_deltas is None:
        # No deltas given: let the network predict and refine them each iteration.
        for _ in range(n_iters_per_batch):
            y_hat, latent, weights_deltas, codes, _ = net.forward(inputs,
                                                                  y_hat=y_hat,
                                                                  codes=codes,
                                                                  weights_deltas=weights_deltas,
                                                                  return_latents=True,
                                                                  resize=resize_outputs,
                                                                  randomize_noise=False,
                                                                  return_weight_deltas_and_codes=True)
            # weights_deltas[14] = None
            # weights_deltas[20] = None
            # weights_deltas[21] = None
            # weights_deltas[23] = None
            # weights_deltas[24] = None
            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)
    else:
        # Deltas supplied by the caller: keep them fixed and only refine the codes.
        for _ in range(n_iters_per_batch):
            y_hat, latent, _, codes, _ = net.forward(inputs,
                                                     y_hat=y_hat,
                                                     codes=codes,
                                                     weights_deltas=weights_deltas,
                                                     return_latents=True,
                                                     resize=resize_outputs,
                                                     randomize_noise=False,
                                                     return_weight_deltas_and_codes=True)
            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)
    if return_intermediate_results:
        return results_batch, results_latent, results_deltas
    return y_hat, latent, weights_deltas, codes
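# Append the current iteration's reconstruction, latent, and deltas to the
# per-sample result dictionaries (one list per image in the batch).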
def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        results_latent[idx].append(latent[idx].cpu().numpy())
        results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas])
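# M is a channel-to-semantic-class membership matrix: k-means segments one
# generator layer's activations into regions, each channel's normalized
# energy is pooled inside every region across all layers, and a sharp
# softmax plus threshold turns the result into a hard per-class mask.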
# compute M given a style code.
def compute_M(w, weights_deltas=None, device='cuda'):
    M = []
    # get segmentation
    # _, outputs = generator(w, is_cluster=1)
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    b, c, h, w = cluster_layer.size()  # note: `w` now holds the width, shadowing the style-code argument (no longer needed)
    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)
    for key in range(n_class):
        # combine masks for all indices for a particular segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1)
        key_mask = (seg_mask == indices.to(device)).any(-1)  # [b,1,h,w]
        all_seg_mask.append(key_mask)
    all_seg_mask = torch.stack(all_seg_mask, 1)
    # go through each activation layer and compute M
    for layer_idx in range(len(outputs)):
        layer = outputs[layer_idx][1].to(device)
        b, c, h, w = layer.size()
        layer = F.instance_norm(layer)
        layer = layer.pow(2)
        # resize the segmentation masks to the current activations' resolution
        layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
                                       size=(h, w), mode='bilinear').view(b, -1, 1, h, w)
        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (h * w)  # [b,k,c]
        M.append(masked_layer.to(device))
    M = torch.cat(M, -1)  # [b,k,c]
    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()
    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j in range(len(part_M)):
            if j in rgb_layer_idx:
                part_M[j].zero_()
        part_M = list2style(part_M)
        M[:, i] = part_M
    return M
#====
# Debugging snippet for selectively copying ref deltas into the blend:
# for i in range(len(blend_deltas)):
#     if blend_deltas[i] is not None:
#         print(f'{i}: {part_M_mask[i].sum()}/{sum(part_M_mask[i].shape)}')
#         if part_M_mask[i].sum() >= sum(part_M_mask[i].shape)/2:
#             print(i)
#             blend_deltas[i] = ref_deltas[i]
def tensor2img(tensor):
    tensor = tensor.cpu().clamp(-1, 1)
    img = topil(tensor.squeeze())
    return img
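# Hair transfer with the HyperStyle encoder: align and invert both faces
# (latent + weight deltas), compute channel masks with compute_M, blend the
# hair channels of ref into source via add_direction, then seamless-clone
# the source face back over the blend to better preserve identity.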
def hair_transfer_hyperstyle(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)
        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
        ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)
        source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
        ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')
        blend_deltas = source_deltas
        max_M = torch.max(source_M.expand_as(ref_M), ref_M)
        max_M = add_pose(max_M, labels2idx)
        idx = labels2idx['hair']
        part_M = max_M[:, idx].to(device)
        part_M_mask = style2list(part_M)
        blend = style2list(add_direction(source, ref, part_M, 1.3))
        blend_out, _ = generator(blend, weights_deltas=blend_deltas)
        source_out = tensor2img(source_out)
        ref_out = tensor2img(ref_out)
        blend_out = tensor2img(blend_out)
        lpips_face, _ = lpips(blend_out, source_out)
        ssim_face, _ = ssim(blend_out, source_out)
        id_face, _ = id_score(blend_out, source_out)
        _, lpips_hair = lpips(blend_out, ref_out)
        _, ssim_hair = ssim(blend_out, ref_out)
        _, clip_hair = clip(blend_out, source_out)
        out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
        # Re-synthesize the blend without weight deltas and segment its hair
        # to build the face mask used for seamless cloning.
        e4e_blend_out, _ = generator(blend)
        e4e_blend_out = tensor2img(e4e_blend_out)
        _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
        source_out_np = np.array(source_out)
        blend_np = np.array(e4e_blend_out).astype(np.uint8)
        e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8) * 255
        mask_dilate = cv2.dilate(e4e_blend_hair_mask,
                                 kernel=np.ones((50, 50), np.uint8))
        mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
        mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
        face_mask = 255 - mask_dilate_blur
        # Clone the source face onto the blended image, centered on the face region.
        index = np.where(face_mask > 0)
        cy = (np.min(index[0]) + np.max(index[0])) // 2
        cx = (np.min(index[1]) + np.max(index[1])) // 2
        center = (cx, cy)
        clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
        return source_out, ref_out, blend_out, out_str, clone_out
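# Variant that synthesizes without the predicted weight deltas (plain
# latent-only reconstruction) and skips the seamless-cloning step.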
def hair_transfer_e4e(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)
        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)
        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')
        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)
        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_blend_out, _ = generator(e4e_blend)
        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        e4e_blend_out = tensor2img(e4e_blend_out)
        e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
        e4e_out_str = f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\ne4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\ne4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}'
        return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str
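# PTI variant: pivotal tuning of the generator on the source image, e4e
# projection of both images, the same hair-channel blend, then synthesis
# through the tuned generator's per-layer style (S-space) interface.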
def hair_transfer_PTI(source_img_path, ref_img_path):
    ckpt = 'pretrained/ffhq.pkl'
    G = Generator_wrapper(ckpt, device)
    manipulator = Manipulator(G, device)
    manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        projection(source_img, 'source', generator, device)
        projection(ref_img, 'ref', generator, device)
        source = load_source('source', generator, device)
        ref = load_source('ref', generator, device)
        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)
        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')
        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)
        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        # e4e_blend_out = tensor2img(e4e_blend_out)
        # e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        # e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        # e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        # _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        # _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        # _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
        # Map the blended per-layer styles onto the tuned generator's affine-layer names.
        keys = ['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine',
                'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine',
                'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine',
                'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine',
                'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine',
                'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine',
                'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine',
                'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine',
                'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine']
        test_dict = dict(zip(keys, e4e_blend))
        manipulator_list = [test_dict]
        all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
    PTI_outstr = 'PTI_outstr'  # placeholder: metrics are not computed for the PTI path
    blend_out = tensor2img(all_imgs[0])
    return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr
# _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
# blend_out_np = np.array(blend_out)
# blend_np = np.array(e4e_blend_out).astype(np.uint8)
# e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
# mask_dilate = cv2.dilate(e4e_blend_hair_mask,
#                          kernel=np.ones((50, 50), np.uint8))
# mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
# mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
# face_mask = 255 - mask_dilate_blur
# index = np.where(face_mask > 0)
# cy = (np.min(index[0]) + np.max(index[0])) // 2
# cx = (np.min(index[1]) + np.max(index[1])) // 2
# center = (cx, cy)
# clone_out = cv2.seamlessClone(blend_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
# out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
# seg_out = torch.tensor(face_mask).float().unsqueeze(-1).repeat(1,1,3)
# seg_out = seg_out.cpu().numpy().astype(np.uint8)
# # seg_out *= 255
# seg_out = Image.fromarray(seg_out)
# # return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, seg_out
# ## Set source_tensor requires_grad=True
# source_tensor.requires_grad = True
# ref_tensor.requires_grad = True
# ckpt = 'pretrained/ffhq.pkl'
# G = Generator_wrapper(ckpt, device)
# manipulator = Manipulator(G, device)
# manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
# blend = style2list(add_direction(source_tensor, ref_tensor, part_M, 1.3))
# keys = (['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine'])
# test_dict = dict(zip(keys, blend))
# manipulator_list = []
# manipulator_list.append(test_dict)
# all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
# return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, all_imgs
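#==== Module-level setup: models, metrics, label maps, and transforms ====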
# Command-line argument for choosing the encoder (e4e or hyperstyle).
parser = argparse.ArgumentParser()
parser.add_argument('--encoder', type=str, default='hyperstyle')
opt = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Face-quality metric callables; `lpips` here shadows the lpips module imported above.
lpips = FaceMetric(metric_type='lpips', device=device)
ssim = FaceMetric(metric_type='ms-ssim', device=device)
id_score = FaceMetric(metric_type='id', device=device)
clip = FaceMetric(metric_type='cliphair', device=device)
# StyleGAN2 generator used for the style-space blending.
generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt['g_ema'], strict=False)
# Pickled FFHQ generator wrapped for the PTI manipulator.
ckpt = 'pretrained/ffhq.pkl'
G = Generator_wrapper(ckpt, device)
manipulator = Manipulator(G, device)
# The dlib landmark predictor is required by align_face in every transfer
# function, so load it once before the encoder branch.
predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')
if opt.encoder == 'e4e':
    from util import align_face
    model_path = 'e4e_ffhq_encode.pt'
    ensure_checkpoint_exists(model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    net = pSp(opts, device).eval().to(device)
elif opt.encoder == 'hyperstyle':
    from hyperstyle.scripts.align_faces_parallel import align_face
    model_path = 'pretrained_models/hyperstyle_ffhq.pt'
    net, _ = load_model(model_path)
else:
    raise ValueError('invalid encoder')
truncation = 0.5
stop_idx = 11  # index of the generator layer whose activations are clustered in compute_M
n_clusters = 18
with open('catalog.pkl', 'rb') as f:
    clusterer = pickle.load(f)
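# labels2idx names the eight semantic classes; labels_map lists which of the
# 18 k-means cluster ids belong to each class.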
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}
idx2labels = {v: k for k, v in labels2idx.items()}
n_class = len(idx2labels)
# Inverse map from raw cluster id to semantic class index.
segid_map = {}
for class_idx, cluster_ids in labels_map.items():
    segid_map.update(dict.fromkeys(cluster_ids.tolist(), class_idx))
torch.manual_seed(0)
# Input pipeline: 256x256 center crop, normalized to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
# Inverse pipeline: undo the normalization and upsample to 1024 for display.
topil = transforms.Compose(
    [
        transforms.Normalize([-1, -1, -1], [2, 2, 2]),
        transforms.ToPILImage(),
        transforms.Resize(1024),
    ]
)
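# One Gradio interface per encoder variant, combined into a tabbed demo.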
e4e_ris_demo = gr.Interface(hair_transfer_e4e, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text", "image"])
PTI_ris_demo = gr.Interface(hair_transfer_PTI, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo], tab_names=["hyperstyle", "e4e", "PTI"])
ris_demo.launch(share=True)