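"""PetBreeder Gradio demo.

Encodes an aligned face photo into StyleGAN2 latent space with an e4e
encoder, then decodes it through 512px generators fine-tuned on cats and
dogs.
"""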
import time
from argparse import Namespace

import dlib
import gradio as gr
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from e4e.models.encoders import psp_encoders
from e4e.models.stylegan2.model import Generator
from e4e.utils.alignment import align_face
from e4e.utils.common import tensor2im

# No-op on a CPU-only Space, but speeds up fixed-size convolutions on GPU.
torch.backends.cudnn.benchmark = True
# e4e expects 256x256 inputs normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
device = 'cpu'
def load_e4e(filename):
    """Download an e4e checkpoint and build its encoder / StyleGAN2 decoder pair."""
    model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename=filename)
    ckpt = torch.load(model_path, map_location='cpu')
    latent_avg = ckpt['latent_avg'].to(device)
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    # The checkpoint stores encoder and decoder weights under 'encoder.' /
    # 'decoder.' key prefixes; strip them before loading.
    encoder = psp_encoders.Encoder4Editing(50, 'ir_se', opts)
    enc_state = {k[len('encoder.'):]: v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
    encoder.load_state_dict(enc_state, strict=True)
    encoder.eval()
    encoder.to(device)
    decoder = Generator(512, 512, 8, channel_multiplier=2)
    dec_state = {k[len('decoder.'):]: v for k, v in ckpt['state_dict'].items() if k.startswith('decoder.')}
    decoder.load_state_dict(dec_state, strict=True)
    decoder.eval()
    decoder.to(device)
    return encoder, decoder, latent_avg

ffhq_encoder, ffhq_decoder, ffhq_latent_avg = load_e4e("e4e_ffhq512.pt")
dog_encoder, dog_decoder, dog_latent_avg = load_e4e("e4e_ffhq512_dog.pt")
cat_encoder, cat_decoder, cat_latent_avg = load_e4e("e4e_ffhq512_cat.pt")
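# All three checkpoints share the same 512px StyleGAN2 architecture
# (2 * log2(512) - 2 = 16 style layers, so codes have shape [B, 16, 512]);
# each encoder maps the same aligned face into its own generator's W+ space.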
# dlib 68-point landmark predictor used by align_face for FFHQ-style cropping.
dlib_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="shape_predictor_68_face_landmarks.dat")
predictor = dlib.shape_predictor(dlib_path)
def run_alignment(image_path):
    # Detect landmarks and crop/align the face the same way FFHQ was preprocessed.
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print(f"Aligned image has shape: {aligned_image.size}")
    return aligned_image
def gen_im(ffhq_codes, dog_codes, cat_codes, model_type='ffhq'):
    if model_type == 'ffhq':
        imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
    elif model_type == 'Dog':
        imgs, _ = dog_decoder([dog_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
    elif model_type == 'Cat':
        imgs, _ = cat_decoder([cat_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
    else:
        # The original fell through to an undefined custom_decoder; fail loudly instead.
        raise ValueError(f"Unknown model type: {model_type!r}")
    return tensor2im(imgs[0])
def set_seed(rd):
    torch.manual_seed(rd)
def inference(img, model):
    # Use a time-based seed for this request, as in the original app.
    set_seed(round(time.time() * 1000))
    img.save('out.jpg')
    try:
        input_image = run_alignment('out.jpg')
    except Exception:
        # No face detected: fall back to the unmodified upload.
        return 'out.jpg'
    with torch.no_grad():
        transformed_image = transform(input_image)
        # e4e predicts per-layer offsets; add the average latent to get W+ codes.
        ffhq_codes = ffhq_encoder(transformed_image.unsqueeze(0).to(device).float())
        ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
        cat_codes = cat_encoder(transformed_image.unsqueeze(0).to(device).float())
        cat_codes = cat_codes + cat_latent_avg.repeat(cat_codes.shape[0], 1, 1)
        dog_codes = dog_encoder(transformed_image.unsqueeze(0).to(device).float())
        dog_codes = dog_codes + dog_latent_avg.repeat(dog_codes.shape[0], 1, 1)
        result_image = gen_im(ffhq_codes, dog_codes, cat_codes, model)
    # tensor2im returns a PIL Image, so save it directly.
    result_image.save('filename.jpeg')
    return 'filename.jpeg'
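# A minimal local smoke test (the 'face.jpg' filename is illustrative, not
# part of the original app):
#
#   result_path = inference(Image.open('face.jpg'), 'Cat')
#   print(result_path)  # 'filename.jpeg' on success, 'out.jpg' if alignment fails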
title = "PetBreeder v1.1"
description = "Gradio Demo for PetBreeder. Based on [Colab](https://colab.research.google.com/github/tg-bomze/collection-of-notebooks/blob/master/PetBreeder.ipynb) by [@MLArt](https://t.me/MLArt)."
# The gr.inputs / gr.outputs namespaces were removed in Gradio 3; use the
# top-level components instead. The output component accepts the file path
# returned by inference().
gr.Interface(
    inference,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=['Cat', 'Dog'], value='Cat', label='Model'),
    ],
    outputs=gr.Image(label='Result'),
    title=title,
    description=description,
).launch()