Spaces:
Runtime error
Runtime error
File size: 2,558 Bytes
bca104a 6245e79 bca104a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm
torch.backends.cudnn.benchmark = True
import copy
from util import *
from PIL import Image
from model import *
import moviepy.video.io.ImageSequenceClip
import scipy
import kornia.augmentation as K
from base64 import b64encode
import gradio as gr
from torchvision import transforms
# Download the two sample images that the UI offers as clickable examples.
for _url, _local_name in (
    ('https://i.imgur.com/HiOTPNg.png', 'mona.png'),
    ('https://i.imgur.com/Cw8HcTN.png', 'painting.png'),
):
    torch.hub.download_url_to_file(_url, _local_name)

# Inference configuration.
device = 'cpu'
latent_dim = 8
n_mlp = 5
num_down = 3
truncation = 1
num_styles = 1

# Build the generator in eval mode on `device`, then replace its weights
# with the checkpoint's 'G_A2B_ema' entry.
G_A2B = Generator(
    256, 4, latent_dim, n_mlp,
    channel_multiplier=1, lr_mlp=.01, n_res=1,
)
G_A2B = G_A2B.to(device).eval()
ensure_checkpoint_exists('GNR_checkpoint.pt')
ckpt = torch.load('GNR_checkpoint.pt', map_location=device)
G_A2B.load_state_dict(ckpt['G_A2B_ema'])

# mean latent: average of the mapping network's output over 1000 random codes.
# NOTE(review): `mean_style` and `truncation` are not used in the visible
# part of this file; kept in case other code relies on them.
with torch.no_grad():
    _random_latents = torch.randn([1000, latent_dim]).to(device)
    mean_style = G_A2B.mapping(_random_latents).mean(0, keepdim=True)

# Preprocessing applied to every input image: resize to 256x256, convert to
# a tensor, then normalise each channel with mean 0.5 and std 0.5.
_half = (0.5, 0.5, 0.5)
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_half, std=_half, inplace=True),
])

# NOTE(review): `plt` is not imported in the visible lines; presumably it is
# provided by one of the star imports (`util` / `model`) -- verify.
plt.rcParams['figure.dpi'] = 200

# Seed right before sampling so the same style code is drawn on every start.
torch.manual_seed(84986)
style = torch.randn([num_styles, latent_dim]).to(device)
def inference(input_im):
    """Translate one image with the generator using the fixed style code.

    Args:
        input_im: PIL image to translate (the Gradio input is configured
            with ``type="pil"``).

    Returns:
        A PIL image in RGB mode holding the generator output for the first
        style code.
    """
    # `test_transform` normalises with 3-element mean/std, so the tensor must
    # have exactly three channels; coerce RGBA / greyscale / palette images.
    rgb_im = input_im.convert('RGB')
    real_A = test_transform(rgb_im).unsqueeze(0).to(device)

    with torch.no_grad():
        A2B_content, _ = G_A2B.encode(real_A)
        fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)

    # Undo Normalize(mean=0.5, std=0.5). `new_tensor` creates the constants
    # with the output's dtype and device, so this also works off-CPU.
    std = fake_A2B.new_tensor((0.5, 0.5, 0.5)).view(3, 1, 1)
    mean = fake_A2B.new_tensor((0.5, 0.5, 0.5)).view(3, 1, 1)
    z = fake_A2B * std + mean

    # ToPILImage scales floats by 255 and casts to uint8; values outside
    # [0, 1] would wrap around and show up as speckle artefacts.
    z = z.clamp(0.0, 1.0)

    # Index the batch dimension explicitly instead of `squeeze()`, which only
    # yields a 3-D image when num_styles == 1.
    return transforms.ToPILImage(mode='RGB')(z[0].cpu())
# Text shown on the demo page.
title = "GANsNRoses"
description = "Diverse im2im selfie to anime translation. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.06561'>GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)</a> | <a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a></p>"

# One PIL image in, one PIL image out.
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio component
# namespaces; confirm they exist in the Gradio version this app is pinned to.
demo_inputs = [gr.inputs.Image(type="pil", label="Input")]
demo_output = gr.outputs.Image(type="pil", label="Output")

# Example rows refer to the image files named in the download calls earlier
# in this script.
demo_examples = [
    ["mona.png"],
    ["painting.png"],
]

# Assemble the interface around `inference` and start serving it.
demo = gr.Interface(
    inference,
    demo_inputs,
    demo_output,
    title=title,
    description=description,
    article=article,
    examples=demo_examples,
)
demo.launch()
|