File size: 1,652 Bytes
ef01fd7
 
 
 
 
000a894
ef01fd7
 
 
 
 
 
 
000a894
 
 
 
ef01fd7
 
 
 
 
 
 
 
 
c9fd122
 
ef01fd7
 
 
 
 
 
 
 
 
000a894
ef01fd7
 
 
 
 
 
 
 
 
c9fd122
 
1c5ae49
ef01fd7
 
 
 
 
 
2f675f6
ef01fd7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import print_function
import torch
import process_stylization
from photo_wct import PhotoWCT
import gradio as gr
from datetime import datetime

# Load the PhotoWCT stylization network once at import time; the same
# instance is reused by every call to `run`.
model_path = './models/photo_wct.pth'
p_wct = PhotoWCT()
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU load
# on CPU-only machines too; `run` moves the model to the requested device.
p_wct.load_state_dict(torch.load(model_path, map_location='cpu'))

def run(content_img, style_img, cuda, post_processing, fast):
    """Stylize ``content_img`` using the style of ``style_img``.

    Args:
        content_img: content image from the first Gradio ``Image`` input
            (presumably a numpy array, gr.Image's default type -- TODO
            confirm against ``process_stylization.stylization_gradio``).
        style_img: style reference image, same format as ``content_img``.
        cuda: if true, run the stylization network on GPU 0. Falls back to
            CPU (with a printed warning) when CUDA is not available.
        post_processing: forwarded unchanged to ``stylization_gradio``.
        fast: index of the selected smoothing algorithm. ``0`` selects
            ``GIFSmoothing``; any other value selects ``Propagator``.

    Returns:
        The value returned by ``process_stylization.stylization_gradio``.
    """
    dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    print("[TimeStamp] {}".format(dt_string))

    # The smoother modules are imported only inside the selected branch,
    # presumably so the unused algorithm's dependencies are never loaded.
    if fast == 0:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()

    # The "Use CUDA" checkbox defaults to True. Without this check,
    # p_wct.cuda(0) raises on hosts that have no usable GPU.
    if cuda and not torch.cuda.is_available():
        print("[Warning] CUDA requested but not available; running on CPU.")
        cuda = False

    if cuda:
        p_wct.cuda(0)
    else:
        p_wct.to('cpu')

    # Pass the (possibly corrected) flag on so the stylization step uses the
    # same device as the model.
    return process_stylization.stylization_gradio(
        stylization_module=p_wct,
        smoothing_module=p_pro,
        content_image=content_img,
        style_image=style_img,
        cuda=cuda,
        post_processing=post_processing
    )

if __name__ == '__main__':

    style = gr.Interface(
        fn=run,
        # Order must match run(content_img, style_img, cuda,
        # post_processing, fast).
        inputs=[
            gr.Image(label='Content Image'),
            gr.Image(label='Stylize Image'),
            gr.Checkbox(value=True, label='Use CUDA'),
            gr.Checkbox(value=True, label='Post Processing'),
            # type="index" passes the position of the selected choice to
            # `run` (0 = guided filtering, 1 = photorealistic smoothing), not
            # the label text. interactive=False locks the selection to the
            # default fast option.
            gr.Radio(
                choices=["Guided Image Filtering (Fast)",
                         "Photorealistic Smoothing (Slow)"],
                value="Guided Image Filtering (Fast)",
                type="index",
                label="Algorithm",
                interactive=False,
            ),
        ],
        outputs=[
            gr.Image(type="pil", label="Result"),
        ],
    )
    # Route requests through Gradio's queue before starting the server.
    style.queue()
    style.launch()