SofaStyler / app.py
Sophie98
change to streamlit
ad1ac8f
# Import libraries
import numpy as np
import streamlit as st
from PIL import ExifTags, Image
from Segmentation.segmentation import get_mask, replace_sofa
from StyleTransfer.styleTransfer import (
StyleFAST,
StyleTransformer,
styleProjection,
)
# Browser-tab title, tab icon and page layout for the Streamlit app.
# NOTE(review): Streamlit's docs ask for set_page_config to run before any
# other st.* call, which is why it sits directly after the imports.
PAGE_CONFIG = dict(
    page_title="SofaStyler.io",
    page_icon=":art:",
    layout="centered",
)
st.set_page_config(**PAGE_CONFIG)
def fix_orient(img: Image.Image) -> Image.Image:
    """
    This function fixes the orientation of input images.
    This is especially useful for photos taken with a mobile phone, which
    record the camera rotation/mirroring in the EXIF "Orientation" tag instead
    of storing upright pixels.
    Parameters:
    img = input image
    Return:
    img = image with the EXIF orientation applied to its pixels; the input is
          returned unchanged when the tag is missing, 1 (already upright) or
          an unrecognised value
    """
    # Numeric id of the EXIF "Orientation" tag, looked up by name in Pillow's
    # tag table.
    orientation_tag = next(
        (tag for tag, name in ExifTags.TAGS.items() if name == "Orientation"),
        None,
    )
    if orientation_tag is None:
        return img
    orientation = img.getexif().get(orientation_tag)
    # EXIF orientation value -> transpose that makes the image upright.
    # Values 2, 4, 5 and 7 describe a mirrored image, so they need a flip
    # (alone or combined with a rotation), not just a rotation.
    method = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }.get(orientation)
    if method is None:
        return img
    return img.transpose(method)
def resize_sofa(
    img: Image.Image, size: int = 640
) -> "tuple[Image.Image, tuple[float, float, float, float]]":
    """
    This function pads the sofa image with white borders to make it square
    and resizes it to size-by-size pixels (640by640 by default).
    Padding (rather than cropping) keeps the whole sofa in view. The returned
    box allows the padding to be cropped off again later.
    Parameters:
    img = original image
    size = side length in pixels of the square output (default 640)
    Return:
    img_square = padded and resized square image
    box = (left, upper, right, lower) region of img_square that holds the
          original picture, suitable for Image.crop
    """
    width, height = img.size
    # "white" is resolved per image mode by Pillow (e.g. 255 for "L"), whereas
    # a fixed RGB tuple is rejected for single-band modes such as grayscale.
    if width <= height:
        # Portrait (or already square): pad left and right.
        img_square = Image.new(img.mode, (height, height), "white")
        img_square.paste(img, ((height - width) // 2, 0))
        margin = size * (1 - width / height) // 2
        box = (margin, 0, size - margin, size)
    else:
        # Landscape: pad top and bottom.
        img_square = Image.new(img.mode, (width, width), "white")
        img_square.paste(img, (0, (width - height) // 2))
        margin = size * (1 - height / width) // 2
        box = (0, margin, size, size - margin)
    img_square = img_square.resize((size, size))
    return img_square, box
def resize_style(img: Image.Image, size: int = 640, copies: int = 8) -> Image.Image:
    """
    This function generates a zoomed-out version of the style image as a
    size-by-size square (640by640 by default).
    The pattern is center-cropped to a square and shrunk to a tile of
    size // copies pixels. It is then tiled copies-by-copies times, with each
    tile mirrored relative to its horizontal and vertical neighbours.
    Parameters:
    img = image containing the style/pattern
    size = side length in pixels of the output image (default 640)
    copies = number of tiles along each side (default 8)
    Return:
    a zoomed-out RGB version of the pattern
    Raises:
    ValueError if copies < 1 or copies > size (the tile would be empty)
    """
    if copies < 1 or copies > size:
        raise ValueError("copies must be between 1 and size")
    width, height = img.size
    # Center-crop to a square. Computing the far edge as start + side keeps
    # the crop exactly square even when width - height is odd.
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    img = img.crop((left, top, left + side, top + side))
    # Construct the zoomed-out, mirrored tiling.
    tile = (size // copies, size // copies)
    img = img.resize(tile)
    img_zoomed_out = Image.new("RGB", (tile[0] * copies, tile[1] * copies))
    for col in range(copies):
        # Flip left-right once per column...
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        for row in range(copies):
            # ...and top-bottom once per tile, so neighbours mirror each other.
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            img_zoomed_out.paste(img, (tile[0] * col, tile[1] * row))
    # Only changes the image when size is not a multiple of copies.
    return img_zoomed_out.resize((size, size))
# Brand logo, displayed in the page header.
image = Image.open("figures/logo.png")
# Styling algorithms the user can pick from, in display order.
options = ["Style Transformer", "StyleFAST", "Style Projection"]
# Header row: a wide column for the title text and a narrow one for the logo.
col1, col2 = st.columns([0.8, 0.2])
with col1:
    # Load the Arvo web font, apply it to the whole page, then define the
    # ".font" class (30px, green) used for headings across the app.
    for css in (
        "\n<style>\n"
        "@import url('https://fonts.googleapis.com/css2?family=Arvo&display=swap');\n"
        "</style>\n",
        "\n<style>\n"
        'html, body, [class*="css"] {\n'
        "font-family: 'Arvo';\n"
        "}\n"
        "</style>\n",
        " <style> .font {\n"
        "font-size:30px ; font-family: 'Arvo'; color: #04b188;\n"
        'src:url("https://fonts.googleapis.com/css2?family=Arvo&display=swap");}\n'
        "</style> ",
    ):
        st.markdown(css, unsafe_allow_html=True)
    title_html = '<p class="font">Upload your photos here...</p>'
    st.markdown(title_html, unsafe_allow_html=True)
with col2:
    # Brand logo next to the title.
    st.image(image, width=150)
# Add a header and expander in side bar
st.sidebar.markdown('<center class="font">🛋 </center>', unsafe_allow_html=True)
st.sidebar.markdown(
    '<center class="font">A sofastyler App</center>', unsafe_allow_html=True
)
st.sidebar.markdown("")
# Collapsible "About" section, built the same way as the "References"
# expander below (st.sidebar.text only renders a plain text element).
with st.sidebar.expander("About the App"):
    st.write(
        "\nCustomize your sofa to your wildest dreams 💭!"
        "\nProvide a picture of your sofa, a desired pattern and"
        " choose one of the algorithms below.\n"
        "\nOr just look at an example.\n"
    )
st.sidebar.title("")
# Links to the datasets/models/repositories the app builds on, rendered as a
# markdown list of links separated by blank lines.
with st.sidebar.expander("References"):
    st.write(
        "[1. The data that was used to train the segmentation model.]"
        "(https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future)"
        "\n\n"
        "[2. Github repository used to train a segmentation model with transfer "
        "learning.]"
        "(https://github.com/qubvel/segmentation_models)"
        "\n\n"
        "[3. The github repository that is used for the style transformer.]"
        "(https://github.com/diyiiyiii/StyTR-2)"
        "\n\n"
        "[4. A tensorflow model for fast arbitrary image style transfer.]"
        "(https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2)"
        "\n\n"
        "[5. A paddleHub model for parameter free style transfer.]"
        "(https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.2/modules/"
        "image/Image_gan/style_transfer/stylepro_artistic)"
    )
# Image formats accepted by both uploaders.
_UPLOAD_TYPES = ("jpg", "png", "jpeg")
# Photo containing the sofa that will be restyled.
uploaded_content = st.file_uploader(label="Image with sofa", type=list(_UPLOAD_TYPES))
# Photo containing the pattern to apply to the sofa.
uploaded_style = st.file_uploader(label="Image with pattern", type=list(_UPLOAD_TYPES))
# Example section
# Pre-rendered demo per algorithm: (sofa photo, pattern photo, styled result).
EXAMPLE_FILES = {
    "Style Transformer": (
        "figures/sofa_example1.jpg",
        "figures/style_example1.jpg",
        "figures/0.png",
    ),
    "StyleFAST": (
        "figures/sofa_example3.jpg",
        "figures/style_example10.jpg",
        "figures/1.png",
    ),
    "Style Projection": (
        "figures/sofa_example2.jpg",
        "figures/style_example6.jpg",
        "figures/2.png",
    ),
}
# Rough processing time per algorithm, shown before the user starts styling.
ETA_BY_ALGORITHM = {
    "Style Transformer": "50s with CPU, 9s with GPU",
    "StyleFAST": "15s with CPU, 3s with GPU",
    "Style Projection": "20s with CPU, 10s with GPU",
}
# Styling function per algorithm. Each is called with the resized sofa and
# pattern images; "Style Projection" additionally receives the alpha weight.
STYLERS = {
    "Style Transformer": StyleTransformer,
    "StyleFAST": StyleFAST,
    "Style Projection": styleProjection,
}
BEFORE_HTML = '<p style="text-align: center;">Before</p>'
AFTER_HTML = '<p style="text-align: center;">After</p>'


def _show_before(column, sofa, pattern):
    """Render the "Before" heading plus the sofa and pattern images in column."""
    with column:
        st.markdown(BEFORE_HTML, unsafe_allow_html=True)
        st.image(sofa, width=300)
        st.image(pattern, width=300)


checkbox = st.checkbox("Show example")
if checkbox:
    algorithm = st.radio("Style your sofa with:", options)
    col1, col2 = st.columns([0.5, 0.5])
    sofa_path, style_path, output = EXAMPLE_FILES[algorithm]
    _show_before(col1, Image.open(sofa_path), Image.open(style_path))
    with col2:
        st.markdown(AFTER_HTML, unsafe_allow_html=True)
        st.image(output, width=300)
# Add 'before' and 'after' columns
elif uploaded_content is not None and uploaded_style is not None:
    content = fix_orient(Image.open(uploaded_content))
    style = fix_orient(Image.open(uploaded_style))
    algorithm = st.radio("Style your sofa with:", options)
    if algorithm == "Style Projection":
        alpha = st.slider(
            "Adjust the weight of the image vs style", 0.0, 1.0, 0.8, step=0.1
        )
    st.info(
        "Estimated processing time: "
        + ETA_BY_ALGORITHM.get(algorithm, "Unknown")
    )
    button = st.button("Style my sofa")
    col1, col2 = st.columns([0.5, 0.5])
    _show_before(col1, content, style)
    if button:
        with col2:
            st.markdown(AFTER_HTML, unsafe_allow_html=True)
            with st.spinner("Preprocessing images..."):
                # preprocess input images to be (640,640) squares
                # to fit requirements of the segmentation model
                resized_img, box = resize_sofa(content)
                resized_style = resize_style(style)
            # generate mask for image
            with st.spinner("generating mask..."):
                mask = get_mask(resized_img)
            # alpha only exists (and is only needed) for Style Projection.
            extra_args = (alpha,) if algorithm == "Style Projection" else ()
            with st.spinner("Styling sofa..."):
                styled_sofa = STYLERS[algorithm](
                    resized_img, resized_style, *extra_args
                )
            # postprocess the final image: paste the styled sofa back into
            # the photo and crop away the padding added by resize_sofa.
            with st.spinner("Replacing sofa..."):
                new_sofa = replace_sofa(resized_img, mask, styled_sofa)
                new_sofa = new_sofa.crop(box)
            st.balloons()
            st.image(new_sofa, width=300)
else:
    # NOTE(review): nesting of this else was inferred (indentation was lost);
    # it is taken to apply when no example is requested and uploads are missing.
    st.image(image, width=300)
# Add a feedback section in the sidebar.
# The submission is only echoed back to the user; nothing here stores it.
st.sidebar.title(" ")  # create space
st.sidebar.markdown(" ")
st.sidebar.subheader("Please help us improve!")
with st.sidebar.form(key="columns_in_form", clear_on_submit=True):
    rating = st.slider(
        "Please rate the app",
        min_value=1,
        max_value=5,
        value=3,
        help="Drag the slider to rate the app. "
        "This is a 1-5 rating scale where 5 is the highest rating",
    )
    text = st.text_input(label="Please leave your feedback here")
    submitted = st.form_submit_button("Submit")
    if submitted:
        st.write("Thanks for your feedback!")
        st.markdown("Your Rating:")
        st.markdown(str(rating))
        st.markdown("Your Feedback:")
        st.markdown(text)