Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,744 Bytes
00710e8 88b9835 a0f5d9f ad85111 75f2ed4 e16d255 75f2ed4 fe13422 75f2ed4 a1732e3 e6f200a 814e69a 91d9343 67d69a3 a84e446 a3814f8 28ac920 917ebd2 75f2ed4 fc92636 a1732e3 00710e8 a3814f8 764d4ab 88b9835 a1732e3 b7a47e5 a1732e3 d35a711 57a96a1 89e4ae0 d35a711 88b9835 a3814f8 962b2f7 91d9343 d35a711 a1732e3 14fd49f d35a711 75f2ed4 ab16048 7f82183 ab16048 7b732c2 3b42de6 ce6dca2 a3814f8 d35a711 a3814f8 ce6dca2 17b3937 fa57e87 ce6dca2 fa57e87 ce6dca2 22696bb ce6dca2 d35a711 ce6dca2 d35a711 ce6dca2 d35a711 ce6dca2 ab16048 ce6dca2 a1732e3 d35a711 ce6dca2 d35a711 ce6dca2 d35a711 ce6dca2 de50edd fc92636 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import time
from datetime import datetime, timezone, timedelta
import spaces
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from utils import preprocess_img, postprocess_img, load_model_without_module
from vgg.vgg19 import VGG_19
from u2net.model import U2Net
from inference import inference
def _pick_device():
    """Return the preferred torch device string: 'cuda', then 'mps', else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


# Device shared by every model and tensor in this app.
device = _pick_device()
print('Device:', device)
if device == 'cuda':
    # Log which GPU was assigned.
    print('Name:', torch.cuda.get_device_name())
# Load both networks once at import time; every request reuses them.
# VGG-19 is kept in eval mode with all of its parameters frozen
# (requires_grad_(False) clears requires_grad on every parameter).
model = VGG_19().to(device).eval().requires_grad_(False)

# U2Net saliency network; its weights are fetched from the Hugging Face Hub
# and loaded with the project helper load_model_without_module.
sod_model = U2Net().to(device).eval()
sod_weights_path = hf_hub_download(
    repo_id='jamino30/u2net-saliency',
    filename='u2net-duts-msra.safetensors',
)
load_model_without_module(sod_model, sod_weights_path, device=device)
# Display label -> style image path. The keys populate the "Style" radio
# and are the lookup keys for cached_style_features.
style_options = {
    'Starry Night': './style_images/Starry_Night.jpg',
    'Starry Night (v2)': './style_images/Starry_Night_v2.jpg',
    'Scream': './style_images/Scream.jpg',
    'Great Wave': './style_images/Great_Wave.jpg',
    'Oil Painting': './style_images/Oil_Painting.jpg',
    'Watercolor': './style_images/Watercolor.jpg',
    'Mosaic': './style_images/Mosaic.jpg',
    'Lego Bricks': './style_images/Lego_Bricks.jpg',
    'Bokeh': './style_images/Bokeh.jpg',
}
# Learning rates selected by the 1-based "Style Strength" slider:
# [0.015, 0.045, 0.075] (increasing strength -> larger learning rate).
lrs = np.linspace(0.015, 0.075, 3).tolist()
# Size argument passed to preprocess_img for both style and content images.
img_size = 512
# Run VGG over every style image once at startup so each request only has to
# process its own content image. no_grad guarantees autograd records nothing
# for these cached, constant features.
with torch.no_grad():
    cached_style_features = {
        style_name: model(preprocess_img(style_img_path, img_size)[0].to(device))
        for style_name, style_img_path in style_options.items()
    }
@spaces.GPU(duration=15)
def run(content_image, style_name, style_strength=len(lrs), apply_to_background=False):
    """Stylize ``content_image`` with a cached style and stream the result.

    This is a generator. It first yields ``None`` so the output component is
    cleared while work is in progress. It then yields the post-processed
    stylized image.

    Args:
        content_image: Content image from the UI (``gr.Image(type='pil')``);
            ``None`` when the user has not provided one.
        style_name: Key into ``cached_style_features``.
        style_strength: 1-based index into ``lrs``. It is converted with
            ``int`` and clamped to ``1..len(lrs)``.
        apply_to_background: Forwarded unchanged to ``inference``.

    Raises:
        gr.Error: If no content image was provided.
    """
    if content_image is None:
        raise gr.Error('Please provide a content image.')

    yield None  # clear the previous result while this one is computed

    content_img, original_size = preprocess_img(content_image, img_size)
    content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
    content_img = content_img.to(device)
    content_img_normalized = content_img_normalized.to(device)
    style_features = cached_style_features[style_name]

    # Indexing lrs with the raw value would let 0 wrap around to lrs[-1],
    # make values above len(lrs) raise IndexError, and make a float raise
    # TypeError; normalize it first.
    strength = min(max(int(style_strength), 1), len(lrs))

    print('-'*30)
    # Fixed UTC-5 offset, labelled as such (daylight saving is not handled).
    # The previous code printed EST wall-clock time tagged "+00:00".
    print(datetime.now(timezone(timedelta(hours=-5), 'EST')))
    st = time.perf_counter()
    generated_img = inference(
        model=model,
        sod_model=sod_model,
        content_image=content_img,
        content_image_norm=content_img_normalized,
        style_features=style_features,
        lr=lrs[strength - 1],
        apply_to_background=apply_to_background,
    )
    print(f'{time.perf_counter() - st:.2f}s')
    yield postprocess_img(generated_img, original_size)
# Page CSS: centers the element with elem_id='container' and caps its width.
css = (
    '\n'
    '#container {\n'
    'margin: 0 auto;\n'
    'max-width: 1200px;\n'
    '}\n'
)
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation</h1>")
    with gr.Row(elem_id='container'):
        # Left column: content image, style controls, submit and examples.
        with gr.Column():
            with gr.Group():
                content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            with gr.Group():
                style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
                apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
            submit_button = gr.Button('Submit', variant='primary')
            examples = gr.Examples(
                examples=[
                    ['./content_images/GoldenRetriever.jpg', 'Great Wave'],
                    ['./content_images/CameraGirl.jpg', 'Bokeh'],
                ],
                inputs=[content_image, style_dropdown],
            )
        # Right column: stylized result and its (initially hidden) download button.
        with gr.Column():
            output_image = gr.Image(label='Output', type='pil', interactive=False, show_download_button=False)
            download_button = gr.DownloadButton(label='Download Image', visible=False)

    def save_image(img):
        """Save the generated image and return an update for the download button.

        Args:
            img: PIL image from ``output_image``, or ``None`` if nothing was
                produced.

        Returns:
            A ``gr.update`` that points the button at the saved file and
            shows it. If ``img`` is ``None``, the update clears and hides it.
        """
        if img is None:
            # Keep the button hidden instead of exposing a file left over
            # from an earlier run.
            return gr.update(value=None, visible=False)
        filename = 'generated.jpg'
        # NOTE(review): one fixed path is shared by all sessions; confirm the
        # event concurrency settings prevent users from overwriting each
        # other's file before it is served.
        img.save(filename)
        return gr.update(value=filename, visible=True)

    # Hide the previous download link as soon as a new request is submitted.
    submit_button.click(
        fn=lambda: gr.update(visible=False),
        outputs=download_button,
    )
    # Run style transfer, then save and reveal the result only if the run
    # succeeded. With `.then` these steps also fired after a failed run,
    # which crashed in save_image(None) and then showed a stale download.
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background_checkbox],
        outputs=output_image,
    ).success(
        fn=save_image,
        inputs=output_image,
        outputs=download_button,
    )
# NOTE(review): this rebinds the Blocks `queue` *method* to the constant
# False; it does not call demo.queue(). It is presumably meant to switch off
# request queuing. Confirm against the installed Gradio version, since the
# generator handler `run` may depend on the queue.
demo.queue = False
# NOTE(review): this mutates the already-built config dict. Whether launch()
# reads this value or regenerates the config is version-dependent; verify.
demo.config['queue'] = False
# Start the server; show_api=False keeps the API docs link out of the UI.
demo.launch(show_api=False)
|