import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import gradio as gr

torch.backends.cudnn.benchmark = True  # no effect on CPU, harmless to leave on

# ensure_checkpoint_exists and Generator below come from these star imports.
from util import *
from model import *

# Fetch the two example images shown in the demo's example gallery.
torch.hub.download_url_to_file('https://i.imgur.com/HiOTPNg.png', 'mona.png')
torch.hub.download_url_to_file('https://i.imgur.com/Cw8HcTN.png', 'painting.png')

# Architecture hyperparameters; these must match the ones the checkpoint
# was trained with, or load_state_dict() below will fail.
device = 'cpu'
latent_dim = 8  # dimensionality of the style code
n_mlp = 5       # depth of the style mapping network
num_down = 3    # number of encoder downsampling steps

# Selfie-to-anime generator (A -> B direction), kept in eval mode for inference.
G_A2B = Generator(256, 4, latent_dim, n_mlp, channel_multiplier=1, lr_mlp=0.01, n_res=1).to(device).eval()

# Download the pretrained checkpoint if missing, then load the EMA
# (exponential-moving-average) weights of the A->B generator.
ensure_checkpoint_exists('GNR_checkpoint.pt')
ckpt = torch.load('GNR_checkpoint.pt', map_location=device)

G_A2B.load_state_dict(ckpt['G_A2B_ema'])

# Mean style vector, estimated from 1,000 random samples of the mapping
# network. With truncation = 1 the sampled styles are used as-is, so
# mean_style is computed here but not applied below.
truncation = 1
with torch.no_grad():
    mean_style = G_A2B.mapping(torch.randn([1000, latent_dim]).to(device)).mean(0, keepdim=True)
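
# A minimal sketch (assumption, not used in this demo) of how truncation < 1
# would trade diversity for typicality by pulling styles toward the mean:
#   style = truncation * style + (1 - truncation) * mean_style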


# Preprocessing: resize to the generator's 256x256 input resolution and
# normalize RGB values from [0, 1] to [-1, 1].
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
])
plt.rcParams['figure.dpi'] = 200

# Fix the seed so the sampled style code (and hence the anime "look") is
# reproducible across runs.
torch.manual_seed(84986)

num_styles = 1
style = torch.randn([num_styles, latent_dim]).to(device)


def inference(input_im):
    # Gradio passes a PIL image; force RGB in case of grayscale/alpha inputs.
    real_A = test_transform(input_im.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        # Encode the photo into a content code, then decode it together with
        # the fixed random style code to get the anime rendition.
        A2B_content, _ = G_A2B.encode(real_A)
        fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)
        # Invert the Normalize above: map [-1, 1] back to [0, 1], clamp, and
        # convert to a PIL image for Gradio.
        std = torch.tensor((0.5, 0.5, 0.5)).view(3, 1, 1)
        mean = torch.tensor((0.5, 0.5, 0.5)).view(3, 1, 1)
        z = (fake_A2B * std + mean).clamp(0, 1)
        return transforms.ToPILImage(mode='RGB')(z.squeeze(0))
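
# A minimal sketch (assumption, not wired into the demo) of the diverse
# translation described in the paper: decode one content code with several
# sampled styles to get multiple anime renditions of the same photo.
#   styles = torch.randn([8, latent_dim]).to(device)
#   outs = G_A2B.decode(A2B_content.repeat(8, 1, 1, 1), styles)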

title = "GANsNRoses"
description = "Diverse image-to-image selfie-to-anime translation. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.06561'>GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)</a> | <a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a></p>"

gr.Interface(
    inference, 
    [gr.inputs.Image(type="pil", label="Input")], 
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["mona.png"],
        ["painting.png"]
    ]).launch()
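
# Note: gr.inputs / gr.outputs above are the Gradio 2.x component namespaces;
# on Gradio 3.x or later the equivalent would be gr.Image(type="pil") for
# both the input and the output.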