File size: 6,722 Bytes
00710e8
88b9835
a0f5d9f
0fc4e06
ad85111
 
75f2ed4
bbcd902
e16d255
75f2ed4
ce6dca2
b33b2b4
fe13422
75f2ed4
57a96a1
e6f200a
814e69a
91d9343
67d69a3
a84e446
 
 
75f2ed4
14fd49f
75f2ed4
814e69a
b33b2b4
814e69a
 
 
 
 
28ac920
 
917ebd2
75f2ed4
 
fc92636
fe13422
 
75f2ed4
00710e8
 
89e4ae0
764d4ab
88b9835
28ac920
b7a47e5
 
764d4ab
980e9a0
764d4ab
b7a47e5
 
814e69a
fa57e87
e21f7c8
57a96a1
89e4ae0
 
88b9835
91d9343
 
 
88b9835
d767ccb
75f2ed4
764d4ab
962b2f7
91d9343
0fc4e06
e21f7c8
 
 
0fc4e06
e21f7c8
0fc4e06
e21f7c8
 
 
 
 
814e69a
e21f7c8
89e4ae0
e21f7c8
 
 
 
0fc4e06
 
e21f7c8
 
 
 
 
 
3b42de6
 
0fc4e06
88b9835
 
14fd49f
0fc4e06
 
3b42de6
0fc4e06
75f2ed4
424869b
 
ab16048
 
 
 
7f82183
ab16048
 
 
7b732c2
3b42de6
ce6dca2
 
 
 
 
3b42de6
ce6dca2
 
 
 
17b3937
 
fa57e87
ce6dca2
fa57e87
ce6dca2
22696bb
ce6dca2
0fc4e06
 
e21f7c8
 
0fc4e06
ce6dca2
0fc4e06
 
 
 
 
ce6dca2
 
0fc4e06
 
ce6dca2
ab16048
ce6dca2
 
0fc4e06
3b42de6
ce6dca2
 
0fc4e06
 
ce6dca2
0fc4e06
 
ce6dca2
de50edd
 
 
fc92636
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import time
from datetime import datetime, timezone, timedelta
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
import torchvision.models as models
import numpy as np
import gradio as gr
from gradio_imageslider import ImageSlider
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

from utils import preprocess_img, preprocess_img_from_path, postprocess_img
from vgg.vgg19 import VGG_19
from u2net.model import U2Net
from inference import inference

# Pick the compute backend once at import time: CUDA first, then Apple MPS,
# and CPU as the fallback. The MPS check is only evaluated when CUDA is absent.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('DEVICE:', device)
if device == 'cuda':
    print('CUDA DEVICE:', torch.cuda.get_device_name())

def load_model_without_module(model, model_path):
    """Load safetensors weights into ``model``, ignoring a ``module.`` key prefix.

    The checkpoint at ``model_path`` is read onto the module-level ``device``.
    Every key that starts with ``'module.'`` has that prefix removed; all other
    keys are passed through unchanged. The renamed mapping is then handed to
    ``model.load_state_dict``.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        model_path: Path of the ``.safetensors`` file to read.
    """
    prefix = 'module.'
    weights = load_file(model_path, device=device)
    renamed = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in weights.items()
    }
    model.load_state_dict(renamed)
   
# load models
# VGG-19 is only ever called for its features here, so it is switched to eval
# mode and every parameter is frozen.
model = VGG_19().to(device)
model.eval()
for weight in model.parameters():
    weight.requires_grad = False

# U2-Net checkpoint is fetched from the Hugging Face Hub and loaded into the
# freshly built network.
sod_model = U2Net().to(device)
sod_model.eval()
local_model_path = hf_hub_download(
    repo_id='jamino30/u2net-saliency',
    filename='u2net-duts-msra.safetensors',
)
load_model_without_module(sod_model, local_model_path)

# Discover the available styles. os.listdir() returns entries in arbitrary
# order and also lists hidden files / sub-directories (e.g. ``.DS_Store``),
# which would otherwise become a style with an empty name and break the
# feature caching below. Keep only regular, non-hidden files and sort them so
# the option order is the same on every start.
style_files = sorted(
    style_file
    for style_file in os.listdir('./style_images')
    if not style_file.startswith('.')
    and os.path.isfile(os.path.join('./style_images', style_file))
)
# Display name -> image path. The name is the part of the file name before the
# first '.', with underscores turned into spaces ("Starry_Night.jpg" -> "Starry Night").
style_options = {
    style_file.split('.')[0].replace('_', ' '): f'./style_images/{style_file}'
    for style_file in style_files
}
# Ten log-spaced learning rates between 0.001 and 0.1; run() picks one of them
# with the 1-based "style strength" value.
lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
# Size argument handed to the preprocess_img* helpers.
img_size = 512

# store style(s) features
def _extract_style_features(image_path):
    """Run the style image at ``image_path`` through ``model`` without gradients."""
    style_tensor = preprocess_img_from_path(image_path, img_size)[0].to(device)
    with torch.no_grad():
        return model(style_tensor)


# Computed once at import so that each request only needs a dict lookup.
cached_style_features = {
    name: _extract_style_features(path)
    for name, path in style_options.items()
}

@spaces.GPU(duration=30)
def run(content_image, style_name, style_strength=10):
    """Stylize ``content_image`` twice and stream the results to the UI.

    Generator used as the Gradio handler of the submit button. Two calls to
    ``inference`` are executed concurrently in worker threads: one with
    ``apply_to_background=False`` and one with ``apply_to_background=True``.

    Args:
        content_image: Content image as delivered by the ``gr.Image`` input.
        style_name: Key into ``cached_style_features``.
        style_strength: 1-based index into ``lrs`` (the slider value). It is
            coerced to ``int`` and clamped to ``1..len(lrs)``.

    Yields:
        First one ``None`` per output component to clear previous results,
        then a pair of ``(content_image, stylized_image)`` tuples, one for
        each ``apply_to_background`` setting.

    Raises:
        gr.Error: If no content image was supplied or ``style_name`` is not a
            cached style.
    """
    # Fail early with a readable message instead of an opaque traceback.
    if content_image is None:
        raise gr.Error('Please provide a content image before submitting.')
    if style_name not in cached_style_features:
        raise gr.Error(f'Unknown style: {style_name!r}')

    # One placeholder per output component (the handler is wired to two outputs).
    yield [None] * 2

    content_img, original_size = preprocess_img(content_image, img_size)
    content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
    content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)

    # The slider may deliver a float, and 0 would silently wrap around to the
    # last entry; coerce and clamp to a valid 1-based position.
    strength_index = min(max(int(style_strength), 1), len(lrs)) - 1
    lr = lrs[strength_index]

    print('-'*15)
    # Wall-clock time at a fixed UTC-4 offset (US Eastern daylight time). Using
    # an aware timezone keeps the printed offset consistent with the shift.
    print('DATETIME:', datetime.now(timezone(timedelta(hours=-4))))
    print('STYLE:', style_name)
    print('CONTENT IMG SIZE:', original_size)
    print('STYLE STRENGTH:', style_strength, f'(lr={lr:.3f})')

    style_features = cached_style_features[style_name]

    st = time.perf_counter()

    if device == 'cuda':
        stream_all = torch.cuda.Stream()
        stream_bg = torch.cuda.Stream()
        # The input tensors were produced on the current stream; make the side
        # streams wait for that work before they start reading them.
        producer_stream = torch.cuda.current_stream()
        stream_all.wait_stream(producer_stream)
        stream_bg.wait_stream(producer_stream)

    def run_inference(apply_to_background):
        return inference(
            model=model,
            sod_model=sod_model,
            content_image=content_img,
            content_image_norm=content_img_normalized,
            style_features=style_features,
            lr=lr,
            apply_to_background=apply_to_background
        )

    def run_inference_cuda(apply_to_background, stream):
        with torch.cuda.stream(stream):
            result = run_inference(apply_to_background)
        # Ensure everything queued on the side stream has finished before the
        # result is consumed from another stream.
        stream.synchronize()
        return result

    # Exactly two jobs are submitted, so two workers are enough.
    with ThreadPoolExecutor(max_workers=2) as executor:
        if device == 'cuda':
            future_all = executor.submit(run_inference_cuda, False, stream_all)
            future_bg = executor.submit(run_inference_cuda, True, stream_bg)
        else:
            future_all = executor.submit(run_inference, False)
            future_bg = executor.submit(run_inference, True)
        generated_img_all = future_all.result()
        generated_img_bg = future_bg.result()

    et = time.perf_counter()
    print('TIME TAKEN:', et-st)

    yield (
        (content_image, postprocess_img(generated_img_all, original_size)),
        (content_image, postprocess_img(generated_img_bg, original_size))
    )

def set_slider(value):
    """Return a Gradio update object that sets a component's value to ``value``."""
    slider_update = gr.update(value=value)
    return slider_update

# Custom stylesheet passed to gr.Blocks(css=...): horizontally centers the
# element with id "container" (the main gr.Row) and caps its width at 1200px.
css = """
#container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

# UI layout: inputs (content image, style, strength, examples) in the left
# column, the two before/after sliders with their download buttons on the right.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation</h1>")
    with gr.Row(elem_id='container'):
        with gr.Column():
            content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
            with gr.Group():
                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
            submit_button = gr.Button('Submit', variant='primary')

            examples = gr.Examples(
                examples=[
                    ['./content_images/Surfer.jpg', 'Starry Night'],
                    ['./content_images/GoldenRetriever.jpg', 'Great Wave'],
                    ['./content_images/CameraGirl.jpg', 'Bokeh']
                ],
                inputs=[content_image, style_dropdown]
            )

        with gr.Column():
            output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
            download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
            with gr.Group():
                output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
            download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)

    def save_image(img_tuple1, img_tuple2):
        """Save the stylized image of each slider and return the two file names.

        Each argument is the value of one ImageSlider; element 1 is the
        stylized image (``run`` yields ``(content_image, stylized)`` pairs).
        The returned names become the values of the two download buttons.

        NOTE(review): the file names are fixed and relative to the working
        directory, so concurrent requests write to the same two files.
        """
        filename1, filename2 = 'generated-all.jpg', 'generated-bg.jpg'
        img_tuple1[1].save(filename1)
        img_tuple2[1].save(filename2)
        return filename1, filename2

    # Hide the download buttons as soon as a new run starts.
    submit_button.click(
        fn=lambda: [gr.update(visible=False) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

    # `.success()` (not `.then()`): the follow-up steps must only run when the
    # previous step finished without raising. With `.then()` a failed `run`
    # would still invoke `save_image` on empty outputs and re-show the buttons.
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider],
        outputs=[output_image_all, output_image_background]
    ).success(
        fn=save_image,
        inputs=[output_image_all, output_image_background],
        outputs=[download_button_1, download_button_2]
    ).success(
        fn=lambda: [gr.update(visible=True) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

# NOTE(review): `queue` is normally a method on gr.Blocks; assigning False here
# replaces that attribute on this instance. Together with the config flag below
# this looks like a workaround to turn request queuing off — confirm it is
# still needed and still valid for the pinned Gradio version.
demo.queue = False
demo.config['queue'] = False
# Start the app; show_api=False presumably hides the API docs link (see Gradio's launch() docs).
demo.launch(show_api=False)