# GANsNRoses / gradiodemo.py
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr

from util import *
from model import *

# cuDNN autotuning for fixed-size inputs (a no-op when running on CPU)
torch.backends.cudnn.benchmark = True
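# Download the sample images used as clickable examples in the UI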
torch.hub.download_url_to_file('https://i.imgur.com/HiOTPNg.png', 'mona.png')
torch.hub.download_url_to_file('https://i.imgur.com/Cw8HcTN.png', 'painting.png')
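# Inference device and generator hyperparameters (these must match the pretrained checkpoint)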
device = 'cpu'
latent_dim = 8
n_mlp = 5
num_down = 3
G_A2B = Generator(256, 4, latent_dim, n_mlp, channel_multiplier=1, lr_mlp=0.01, n_res=1).to(device).eval()
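# Make sure the pretrained checkpoint is available, then load the EMA weights for G_A2B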
ensure_checkpoint_exists('GNR_checkpoint.pt')
ckpt = torch.load('GNR_checkpoint.pt', map_location=device)
G_A2B.load_state_dict(ckpt['G_A2B_ema'])
# Mean style vector, averaged over 1000 random latents
truncation = 1
with torch.no_grad():
    mean_style = G_A2B.mapping(torch.randn([1000, latent_dim]).to(device)).mean(0, keepdim=True)
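# Preprocessing: resize to 256x256, convert to a tensor, and normalize to [-1, 1]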
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True),
])
plt.rcParams['figure.dpi'] = 200
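# Fix the seed so the same randomly sampled style code is reused for every request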
torch.manual_seed(84986)
num_styles = 1
style = torch.randn([num_styles, latent_dim]).to(device)
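# Gradio inference callback: photo in, anime-styled image out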
def inference(input_im):
    # Preprocess the input PIL image and add a batch dimension
    real_A = test_transform(input_im).unsqueeze(0).to(device)
    with torch.no_grad():
        # Encode the content of the input, then decode it with the fixed style code
        A2B_content, _ = G_A2B.encode(real_A)
        fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)
        # Undo the [-1, 1] normalization before converting back to a PIL image
        std = (0.5, 0.5, 0.5)
        mean = (0.5, 0.5, 0.5)
        z = fake_A2B * torch.tensor(std).view(3, 1, 1)
        z = z + torch.tensor(mean).view(3, 1, 1)
        tensor_to_pil = transforms.ToPILImage(mode='RGB')(z.squeeze())
    return tensor_to_pil
title = "GANsNRoses"
description = "Demo for GANsNRoses. To use it, upload your image or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.06561'>GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)</a> | <a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a></p>"
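# Build the Gradio interface and launch the demo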
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["mona.png"],
        ["painting.png"],
    ],
).launch()