import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import pickle
import warnings
warnings.filterwarnings("ignore")
from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
import glob
from pathlib import Path
import argparse
import gc
import cv2
import dlib  # landmark predictor used by align_face below
from PIL import Image  # used to wrap aligned faces before the transforms
from e4e_projection import projection
from model import *
from util import *
from e4e.models.psp import pSp
import torchvision.transforms as transforms
from torch.nn import DataParallel
import torchvision.transforms.functional as TF
from FaceQualityMetrics.utils import FaceMetric
from hyperstyle.utils.model_utils import load_model
from configs.paths_config import model_paths
from manipulator import Manipulator
from wrapper import Generator_wrapper
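# Project-local dependencies: spherical_kmeans, e4e_projection, model, util,
# FaceQualityMetrics, hyperstyle, manipulator and wrapper are modules shipped
# alongside this script, not pip packages. Helpers used below without an
# explicit import (style2list, list2style, add_pose, add_direction,
# flatten_act, load_source, rgb_layer_idx, ensure_checkpoint_exists) are
# assumed to come from the starred `model`/`util` imports.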
def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False, resize_outputs=False, weights_deltas=None):
    """Iteratively invert `inputs` with a HyperStyle-style network.

    When `weights_deltas` is None they are re-estimated on every refinement
    step; otherwise the given deltas are kept fixed and only the codes are
    refined.
    """
    y_hat, latent, codes = None, None, None
    if return_intermediate_results:
        results_batch = {idx: [] for idx in range(inputs.shape[0])}
        results_latent = {idx: [] for idx in range(inputs.shape[0])}
        results_deltas = {idx: [] for idx in range(inputs.shape[0])}
    else:
        results_batch, results_latent, results_deltas = None, None, None
    refine_deltas = weights_deltas is None
    for _ in range(n_iters_per_batch):
        y_hat, latent, new_deltas, codes, _ = net.forward(inputs,
                                                          y_hat=y_hat,
                                                          codes=codes,
                                                          weights_deltas=weights_deltas,
                                                          return_latents=True,
                                                          resize=resize_outputs,
                                                          randomize_noise=False,
                                                          return_weight_deltas_and_codes=True)
        if refine_deltas:
            weights_deltas = new_deltas
        if return_intermediate_results:
            store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
        # resize the reconstruction to 256 before feeding it into the next iteration
        y_hat = net.face_pool(y_hat)
    if return_intermediate_results:
        return results_batch, results_latent, results_deltas
    return y_hat, latent, weights_deltas, codes
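# Usage sketch (illustrative, not part of the original script): with a loaded
# HyperStyle/e4e net and a normalized [1, 3, 256, 256] tensor `img`,
#   y_hat, latent, deltas, codes = run_inversion(img, net, n_iters_per_batch=5)
# returns the final reconstruction, its latent, per-layer weight deltas and codes.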
def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        results_latent[idx].append(latent[idx].cpu().numpy())
        results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas])
# compute M given a style code.
def compute_M(w, weights_deltas=None, device='cuda'):
    """Assign each style channel to a segmentation class for the code `w`."""
    M = []
    # get the segmentation from the clustering layer's activations
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    b, c, h, width = cluster_layer.size()
    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, width, 1).to(device)
    for key in range(n_class):
        # combine masks for all cluster indices belonging to this segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1)
        key_mask = (seg_mask == indices.to(device)).any(-1)  # [b, 1, h, w]
        all_seg_mask.append(key_mask)
    all_seg_mask = torch.stack(all_seg_mask, 1)
    # go through each activation layer and compute M
    for layer_idx in range(len(outputs)):
        layer = outputs[layer_idx][1].to(device)
        b, c, h, width = layer.size()
        layer = F.instance_norm(layer)
        layer = layer.pow(2)
        # resize the segmentation masks to the current activations' resolution
        layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
                                       size=(h, width), mode='bilinear').view(b, -1, 1, h, width)
        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b, k, c, h, w]
        masked_layer = masked_layer.sum([3, 4]) / (h * width)  # [b, k, c]
        M.append(masked_layer.to(device))
    M = torch.cat(M, -1)  # [b, k, c]
    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()
    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j in range(len(part_M)):
            if j in rgb_layer_idx:
                part_M[j].zero_()
        part_M = list2style(part_M)
        M[:, i] = part_M
    return M
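# M has shape [b, n_class, total_style_channels] (channels concatenated over
# all layers): after the temperature-0.1 softmax over classes and the 0.8
# threshold, M[:, k] is a hard 0/1 mask selecting the style channels
# attributed to class k (e.g. 'hair'), with torgb channels zeroed out per the
# paper linked above so color layers are excluded from the transfer.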
def tensor2img(tensor):
    tensor = tensor.cpu().clamp(-1, 1)
    img = topil(tensor.squeeze())
    return img
def hair_transfer_hyperstyle(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)
        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
        ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)
        source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
        ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')
        blend_deltas = source_deltas
        max_M = torch.max(source_M.expand_as(ref_M), ref_M)
        max_M = add_pose(max_M, labels2idx)
        idx = labels2idx['hair']
        part_M = max_M[:, idx].to(device)
        part_M_mask = style2list(part_M)
        blend = style2list(add_direction(source, ref, part_M, 1.3))
        blend_out, _ = generator(blend, weights_deltas=blend_deltas)
        source_out = tensor2img(source_out)
        ref_out = tensor2img(ref_out)
        blend_out = tensor2img(blend_out)
        lpips_face, _ = lpips(blend_out, source_out)
        ssim_face, _ = ssim(blend_out, source_out)
        id_face, _ = id_score(blend_out, source_out)
        _, lpips_hair = lpips(blend_out, ref_out)
        _, ssim_hair = ssim(blend_out, ref_out)
        _, clip_hair = clip(blend_out, source_out)
        out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
        # paste the source face back over the blended result with Poisson cloning
        e4e_blend_out, _ = generator(blend)
        e4e_blend_out = tensor2img(e4e_blend_out)
        _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
        source_out_np = np.array(source_out)
        blend_np = np.array(e4e_blend_out).astype(np.uint8)
        e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8) * 255
        mask_dilate = cv2.dilate(e4e_blend_hair_mask,
                                 kernel=np.ones((50, 50), np.uint8))
        mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
        # keep the hard mask where it is already 255 and feather everywhere else
        mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
        face_mask = 255 - mask_dilate_blur
        index = np.where(face_mask > 0)
        cy = (np.min(index[0]) + np.max(index[0])) // 2
        cx = (np.min(index[1]) + np.max(index[1])) // 2
        center = (cx, cy)
        clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
        return source_out, ref_out, blend_out, out_str, clone_out
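# Feathering arithmetic above, worked per pixel: where the hard hair mask is
# already 255 the term (255 - mask) / 255 vanishes and the value stays 255;
# where the mask is 0 the result equals the blurred, dilated mask, giving a
# soft falloff outside the hair before the mask is inverted into face_mask
# for cv2.seamlessClone.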
def hair_transfer_e4e(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)
        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)
        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')
        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)
        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_blend_out, _ = generator(e4e_blend)
        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        e4e_blend_out = tensor2img(e4e_blend_out)
        e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
        e4e_out_str = f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\ne4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\ne4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}'
        return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str
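# Compared with hair_transfer_hyperstyle above, this path regenerates with the
# plain StyleGAN2 weights (no per-image weight deltas), so reconstructions are
# typically less faithful to the inputs, while the mask computation and
# hair-direction blending are identical.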
def hair_transfer_PTI(source_img_path, ref_img_path):
    ckpt = 'pretrained/ffhq.pkl'
    G = Generator_wrapper(ckpt, device)
    manipulator = Manipulator(G, device)
    manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        projection(source_img, 'source', generator, device)
        projection(ref_img, 'ref', generator, device)
        source = load_source('source', generator, device)
        ref = load_source('ref', generator, device)
        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)
        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')
        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)
        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        # map the per-layer style list onto the StyleSpace affine layers of the
        # PTI-tuned generator, coarsest (b4) to finest (b1024)
        keys = [f'G.synthesis.b{res}.{layer}.affine'
                for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]
                for layer in (['conv1', 'torgb'] if res == 4 else ['conv0', 'conv1', 'torgb'])]
        test_dict = dict(zip(keys, e4e_blend))
        manipulator_list = [test_dict]
        all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
        PTI_outstr = 'PTI_outstr'
        blend_out = tensor2img(all_imgs[0])
        return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr
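# Note: zipping the 26-entry blended style list onto these StyleSpace keys
# assumes style2list emits layers in exactly the synthesis order above
# (conv1/torgb at 4x4, then conv0/conv1/torgb per resolution); that ordering
# is taken on trust from the project's util helpers.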
# command-line argument for choosing the encoder, e4e or hyperstyle
parser = argparse.ArgumentParser()
parser.add_argument('--encoder', type=str, default='hyperstyle')
opt = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lpips = FaceMetric(metric_type='lpips', device=device)
ssim = FaceMetric(metric_type='ms-ssim', device=device)
id_score = FaceMetric(metric_type='id', device=device)
clip = FaceMetric(metric_type='cliphair', device=device)
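# Each FaceMetric call appears to return a (face-region, hair-region) score
# pair: the transfer functions keep the face score when comparing the blend
# against the source and the hair score when comparing against the reference.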
generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt['g_ema'], strict=False)
ckpt = 'pretrained/ffhq.pkl'
G = Generator_wrapper(ckpt, device)
manipulator = Manipulator(G, device)
# align_face needs the dlib landmark predictor on both encoder paths
predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')
if opt.encoder == 'e4e':
    from util import align_face
    model_path = 'e4e_ffhq_encode.pt'
    ensure_checkpoint_exists(model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    net = pSp(opts, device).eval().to(device)
elif opt.encoder == 'hyperstyle':
    from hyperstyle.scripts.align_faces_parallel import align_face
    model_path = 'pretrained_models/hyperstyle_ffhq.pt'
    net, _ = load_model(model_path)
else:
    raise ValueError('invalid encoder')
truncation = 0.5
stop_idx = 11  # index of the layer whose activations are clustered
n_clusters = 18
with open('catalog.pkl', 'rb') as f:
    clusterer = pickle.load(f)
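# catalog.pkl is assumed to hold a MiniBatchSphericalKMeans already fitted
# with n_clusters=18 on layer-`stop_idx` activations; compute_M only calls
# its .predict() on flattened activations.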
labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}
# map each semantic class to the k-means cluster indices that belong to it
labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}
idx2labels = dict((v, k) for k, v in labels2idx.items())
n_class = len(idx2labels)
# inverse lookup: cluster index -> semantic class
segid_map = {seg_id: cls for cls, ids in labels_map.items() for seg_id in ids.tolist()}
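# Example: cluster id 9 appears only in labels_map[5], so segid_map[9] == 5
# ('cheek'); the 18 cluster ids cover classes 0-7 exactly once, with no
# overlap between classes.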
torch.manual_seed(0)
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
topil = transforms.Compose(
    [
        transforms.Normalize([-1, -1, -1], [2, 2, 2]),
        transforms.ToPILImage(),
        transforms.Resize(1024),
    ]
)
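# The two pipelines are inverses up to resizing: `transform` maps [0, 1]
# pixels to [-1, 1] via (x - 0.5) / 0.5, and `topil` undoes it with
# (x + 1) / 2 (Normalize with mean -1, std 2) before converting back to a
# 1024px PIL image.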
e4e_ris_demo = gr.Interface(hair_transfer_e4e, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text", "image"])
PTI_ris_demo = gr.Interface(hair_transfer_PTI, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo], tab_names=["hyperstyle", "e4e", "PTI"])
ris_demo.launch(share=True)