import os
os.system("pip freeze")  # print the installed package versions to the logs for debugging

import torch
import PIL.Image
import PIL.ImageOps
import gradio as gr
from utils import align_face
from torchvision import transforms
from huggingface_hub import hf_hub_download

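# Run on the first GPU when CUDA is available, otherwise fall back to CPU.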
device = "cuda:0" if torch.cuda.is_available() else "cpu"

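# Face-alignment output/transform resolutions and normalization constants (inputs are scaled to [-1, 1]).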
image_size = 512
transform_size = 1024

means = [0.5, 0.5, 0.5]
stds = [0.5, 0.5, 0.5]

img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])
 
# Download the scripted (TorchScript) style-transfer weights from the Hugging Face Hub.
model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")

# GPU inference runs in half precision; CPU inference stays in full precision.
if 'cuda' in device:
    style_transfer = torch.jit.load(model_path).eval().cuda().half()
    t_stds = torch.tensor(stds).cuda().half()[:,None,None]
    t_means = torch.tensor(means).cuda().half()[:,None,None]
else:
    style_transfer = torch.jit.load(model_path).eval().cpu()
    t_stds = torch.tensor(stds).cpu()[:,None,None]
    t_means = torch.tensor(means).cpu()[:,None,None]

def tensor2im(var):
    # Undo the [-1, 1] normalization and return an HWC tensor with values in [0, 255].
    return var.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)

def proc_pil_img(input_image):
    # Normalize the input, run the style-transfer network, and convert the result back to a PIL image.
    if 'cuda' in device:
        transformed_image = img_transforms(input_image)[None, ...].cuda().half()
    else:
        transformed_image = img_transforms(input_image)[None, ...].cpu()

    with torch.no_grad():
        result_image = style_transfer(transformed_image)[0]
        output_image = tensor2im(result_image)
        output_image = output_image.detach().cpu().numpy().astype('uint8')
        output_image = PIL.Image.fromarray(output_image)
    return output_image

def process(im, is_align):
    # Respect EXIF orientation, optionally align/crop the face, then stylize the image.
    im = PIL.ImageOps.exif_transpose(im)

    if is_align == 'True':
        im = align_face(im, output_size=image_size, transform_size=transform_size)

    res = proc_pil_img(im)

    return res
        
# Build and launch the demo UI (legacy Gradio gr.inputs / gr.outputs API).
gr.Interface(
    process,
    inputs=[
        gr.inputs.Image(type="pil", label="Input", shape=(image_size, image_size)),
        gr.inputs.Radio(['True', 'False'], type="value", default='True', label='face align'),
    ],
    outputs=gr.outputs.Image(type="pil", label="Output"),
    title="Arcane Style Transfer",
    description="Gradio demo for Arcane Style Transfer",
    article="<p style='text-align: center'><a href='https://github.com/jjeamin/anime_style_transfer_pytorch' target='_blank'>Github Repo by jjeamin</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=jjeamin_arcane_st' alt='visitor badge'></center>",
    examples=[['billie.png', 'True'], ['gongyoo.jpeg', 'True'], ['IU.png', 'True'], ['elon.png', 'True']],
    allow_flagging=False,
    allow_screenshot=False,
).launch(enable_queue=True)