|
import os |
|
import streamlit as st |
|
import gdown |
|
from packaging.version import Version |
|
|
|
from infer_func import convert |
|
|
|
# Directory that contains this script; bundled assets and weights live under it.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Example images offered in the UI: role ('content' / 'style') -> display
# name -> path of the image file.
EXAMPLES = {
    'content': {'Brad Pitt': f'{ROOT}/examples/content/brad_pitt.jpg'},
    'style': {'Flower of Life': f'{ROOT}/examples/style/flower_of_life.jpg'},
}

# Google Drive URLs the model weights are downloaded from.
VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'

# Local destinations of the downloaded weights.
VGG_WEIGHT_FILENAME = f'{ROOT}/vgg.pth'
DECODER_WEIGHT_FILENAME = f'{ROOT}/decoder.pth'
|
|
|
|
|
@st.cache
def download_models():
    """Make sure the VGG and decoder weight files exist on disk.

    Each file is downloaded with ``gdown`` from its Google Drive URL, but only
    when the destination file is missing. Without that check, both files were
    fetched again every time this function actually ran (e.g. the first call
    after the app process restarts), even if they were already present.

    Returns:
        None. The function is called for its side effect of creating
        ``VGG_WEIGHT_FILENAME`` and ``DECODER_WEIGHT_FILENAME``.
    """
    downloads = (
        ("Downloading VGG weights...", VGG_WEIGHT_URL, VGG_WEIGHT_FILENAME),
        ("Downloading Decoder weights...", DECODER_WEIGHT_URL, DECODER_WEIGHT_FILENAME),
    )
    for message, url, filename in downloads:
        # Skip files that are already on disk; only the missing ones are fetched.
        if os.path.exists(filename):
            continue
        with st.spinner(text=message):
            gdown.download(url, output=filename)
|
|
|
|
|
def image_getter(image_kind):
    """Render the widgets for choosing one image and return the selection.

    The user first picks a source (bundled example, file upload or, when the
    installed Streamlit is new enough, the camera) and then the matching
    widget for that source is shown.

    Args:
        image_kind: Key into ``EXAMPLES`` (the caller passes 'content' or
            'style'). It is also the key of the source selectbox and the
            prefix of the keys of the dependent widgets.

    Returns:
        The path of the chosen example image, or whatever
        ``st.file_uploader`` / ``st.camera_input`` returned for the chosen
        source.
    """
    options = ['Use Example Image', 'Upload Image']

    # The camera option is only offered on Streamlit 1.4.0 or newer.
    if Version(st.__version__) >= Version('1.4.0'):
        options.append('Open Camera')

    option = st.selectbox('Choose Image', options, key=image_kind)

    # The dependent widgets must NOT reuse ``image_kind`` as their key: two
    # widgets rendered in the same run with the same key clash in
    # ``st.session_state`` (and are rejected as duplicates by Streamlit), so
    # each one gets its own suffixed key.
    if option == 'Use Example Image':
        image_key = st.selectbox(
            'Choose from examples',
            EXAMPLES[image_kind],
            key=f'{image_kind}_example',
        )
        return EXAMPLES[image_kind][image_key]

    if option == 'Upload Image':
        return st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'],
            key=f'{image_kind}_upload',
        )

    if option == 'Open Camera':
        return st.camera_input('', key=f'{image_kind}_camera')

    return None
|
|
|
|
|
if __name__ == '__main__':
    # Page layout: column 1 holds the image pickers, column 2 previews the
    # chosen images, column 3 holds the controls and the stylized result.
    st.set_page_config(layout="wide")
    st.header('Adaptive Instance Normalization demo based on '
              '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')

    download_models()

    # Image shown while no content/style image has been provided. It is
    # anchored at ROOT like every other bundled asset, so the app does not
    # depend on the directory it was launched from.
    placeholder = ROOT + '/examples/img.png'

    col1, col2, col3 = st.columns((3, 4, 4))

    with col1:
        st.subheader('Content Image')
        content = image_getter('content')
        st.subheader('Style Image')
        style = image_getter('style')

    with col2:
        # Both previews always have something to show (selection or
        # placeholder), so no None check is needed before st.image.
        st.image(content if content is not None else placeholder,
                 width=None, caption='Content Image')
        st.image(style if style is not None else placeholder,
                 width=None, caption='Style Image')

    with col3:
        color_control = st.checkbox('Preserve content image color')
        alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
        process = st.button('Stylize')

    # Default result location; it is displayed below whenever the file exists,
    # including on reruns where 'Stylize' was not pressed.
    # NOTE(review): presumably convert() writes its result to this path and
    # returns it -- confirm against infer_func.convert.
    output_image = 'output.png'
    if content is not None and style is not None and process:
        print(content, style)
        with st.spinner('Processing...'):
            output_image = convert(content, style, VGG_WEIGHT_FILENAME,
                                   DECODER_WEIGHT_FILENAME, alpha, color_control)

    if os.path.exists(output_image):
        with col3:
            st.image(output_image, width=None, caption='Stylized Image')
|
|
|
|