File size: 4,744 Bytes
00710e8
88b9835
a0f5d9f
ad85111
 
75f2ed4
e16d255
75f2ed4
fe13422
75f2ed4
a1732e3
e6f200a
814e69a
91d9343
67d69a3
a84e446
 
 
a3814f8
 
28ac920
 
917ebd2
75f2ed4
 
fc92636
a1732e3
 
 
 
 
 
00710e8
a3814f8
 
 
 
 
 
 
 
 
 
 
 
764d4ab
88b9835
a1732e3
 
 
 
b7a47e5
a1732e3
 
d35a711
57a96a1
89e4ae0
 
d35a711
88b9835
a3814f8
 
962b2f7
91d9343
d35a711
 
 
 
 
 
 
 
 
a1732e3
14fd49f
d35a711
75f2ed4
ab16048
 
 
7f82183
ab16048
 
 
7b732c2
3b42de6
ce6dca2
 
 
a3814f8
d35a711
a3814f8
 
 
ce6dca2
 
 
 
17b3937
fa57e87
ce6dca2
fa57e87
ce6dca2
22696bb
ce6dca2
d35a711
 
ce6dca2
d35a711
 
 
 
ce6dca2
 
d35a711
 
ce6dca2
ab16048
ce6dca2
 
a1732e3
d35a711
ce6dca2
 
d35a711
 
ce6dca2
d35a711
 
ce6dca2
de50edd
 
 
fc92636
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import tempfile
import time
from datetime import datetime, timezone, timedelta

import spaces
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

from utils import preprocess_img, postprocess_img, load_model_without_module
from vgg.vgg19 import VGG_19
from u2net.model import U2Net
from inference import inference

# Choose the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
# The conditional expression short-circuits left to right, so MPS is only
# probed when CUDA is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('Device:', device)
if device == 'cuda':
    print('Name:', torch.cuda.get_device_name())
   
# Load the networks once at import time.
# VGG-19 in eval mode with every parameter frozen (requires_grad=False);
# Module.requires_grad_ applies the flag to all of model.parameters().
model = VGG_19().to(device).eval().requires_grad_(False)

# U2Net saliency network; its weights are fetched from the Hugging Face Hub
# and loaded onto the selected device by the project helper.
sod_model = U2Net().to(device).eval()
load_model_without_module(
    sod_model,
    hf_hub_download(
        repo_id='jamino30/u2net-saliency',
        filename='u2net-duts-msra.safetensors',
    ),
    device=device,
)

# Display name -> path of the reference image for each selectable style.
# These keys are also the choices offered by the "Style" radio in the UI.
style_options = {
    'Starry Night': './style_images/Starry_Night.jpg',
    'Starry Night (v2)': './style_images/Starry_Night_v2.jpg',
    'Scream': './style_images/Scream.jpg',
    'Great Wave': './style_images/Great_Wave.jpg',
    'Oil Painting': './style_images/Oil_Painting.jpg',
    'Watercolor': './style_images/Watercolor.jpg',
    'Mosaic': './style_images/Mosaic.jpg',
    'Lego Bricks': './style_images/Lego_Bricks.jpg',
    'Bokeh': './style_images/Bokeh.jpg',
}

# Learning rates passed to inference(); the 1-based "Style Strength" slider
# indexes into this ascending list, so higher strength -> larger learning rate.
lrs = np.linspace(0.015, 0.075, 3).tolist()

# Size argument passed to preprocess_img for both style and content images.
img_size = 512

# Extract features for every style image once at startup, so a request only
# has to look them up by style name. no_grad makes explicit that these are
# fixed targets and guarantees no autograd graph is kept alive with them.
with torch.no_grad():
    cached_style_features = {
        style_name: model(preprocess_img(style_img_path, img_size)[0].to(device))
        for style_name, style_img_path in style_options.items()
    }

def _eastern_now():
    """Return the current time in US Eastern time, for the per-request log line.

    Uses the IANA ``America/New_York`` zone so daylight saving time is honoured;
    a fixed UTC-5 shift is an hour off for most of the year. Falls back to a
    fixed UTC-5 offset when ``zoneinfo`` or the system tz database is missing.
    """
    try:
        from zoneinfo import ZoneInfo  # stdlib since Python 3.9
        return datetime.now(ZoneInfo('America/New_York'))
    except (ImportError, KeyError):  # ZoneInfoNotFoundError subclasses KeyError
        return datetime.now(timezone(timedelta(hours=-5), 'EST'))


@spaces.GPU(duration=15)
def run(content_image, style_name, style_strength=len(lrs), apply_to_background=False):
    """Stylise ``content_image`` using the cached features of ``style_name``.

    Written as a generator so the output component is updated twice: first
    ``None`` (clears any previous result), then the final image produced by
    ``postprocess_img``.

    Args:
        content_image: Image from the content input (``type='pil'`` in the UI);
            ``None`` when the user submits without providing one.
        style_name: Key into ``cached_style_features`` / ``style_options``.
        style_strength: 1-based index into ``lrs``; converted to ``int`` and
            clamped to ``1..len(lrs)`` so out-of-range values cannot wrap
            around via negative indexing.
        apply_to_background: Forwarded unchanged to ``inference``.

    Raises:
        gr.Error: if no content image was provided.
    """
    yield None
    if content_image is None:
        raise gr.Error('Please upload a content image first.')

    content_img, original_size = preprocess_img(content_image, img_size)
    content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
    content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
    style_features = cached_style_features[style_name]

    # int() guards against the slider value arriving as a float; the clamp
    # keeps the index inside lrs.
    strength = min(max(int(style_strength), 1), len(lrs))

    print('-'*30)
    print(_eastern_now())

    st = time.time()
    generated_img = inference(
        model=model,
        sod_model=sod_model,
        content_image=content_img,
        content_image_norm=content_img_normalized,
        style_features=style_features,
        lr=lrs[strength - 1],
        apply_to_background=apply_to_background,
    )
    print(f'{time.time()-st:.2f}s')

    yield postprocess_img(generated_img, original_size)

css = """
#container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation</h1>")
    with gr.Row(elem_id='container'):
        # Left column: content image, style controls, submit button and examples.
        with gr.Column():
            with gr.Group():
                content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            with gr.Group():
                style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
                apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
            submit_button = gr.Button('Submit', variant='primary')

            examples = gr.Examples(
                examples=[
                    ['./content_images/GoldenRetriever.jpg', 'Great Wave'],
                    ['./content_images/CameraGirl.jpg', 'Bokeh']
                ],
                inputs=[content_image, style_dropdown]
            )

        # Right column: generated image and a download button that is only
        # shown once a result has been saved.
        with gr.Column():
            output_image = gr.Image(label='Output', type='pil', interactive=False, show_download_button=False)
            download_button = gr.DownloadButton(label='Download Image', visible=False)

    def save_image(img):
        """Save the generated PIL image to a fresh temp file and return its path.

        Every call writes into its own ``tempfile.mkdtemp()`` directory, so two
        requests cannot overwrite each other's file before Gradio serves it.
        The basename stays ``generated.jpg``, which keeps the download name.
        Returns ``None`` (clearing the button's file) when there is no image.

        NOTE: the per-call temp directories are not cleaned up here.
        """
        if img is None:
            return None
        filename = os.path.join(tempfile.mkdtemp(prefix='nst_'), 'generated.jpg')
        img.save(filename)
        return filename

    # Hide the download button as soon as a new run is submitted.
    submit_button.click(
        fn=lambda: gr.update(visible=False),
        outputs=download_button
    )

    # .success() rather than .then(): if run() raises (e.g. no content image),
    # the save/show steps are skipped instead of calling save_image(None) and
    # revealing a download button with nothing behind it.
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background_checkbox],
        outputs=output_image
    ).success(
        fn=save_image,
        inputs=output_image,
        outputs=download_button
    ).success(
        fn=lambda: gr.update(visible=True),
        outputs=download_button
    )

# NOTE(review): the first line rebinds the instance attribute `queue` (normally
# the Blocks.queue() method) to the bool False instead of calling it. The
# second overrides the 'queue' entry of the config dict built for the app.
# Whether these two lines actually disable request queuing depends on the
# installed Gradio version -- confirm. `run` is a generator, and Gradio has
# historically served generator outputs through the queue.
demo.queue = False
demo.config['queue'] = False
# Start the server; show_api=False is intended to hide the "Use via API" docs.
demo.launch(show_api=False)