CycleGAN / app.py
Yanguan's picture
1
c1c16cb
raw
history blame contribute delete
No virus
2.74 kB
# 2023年2月23日
"""
实现web界面
>>> streamlit run app.py
"""
from io import BytesIO
from pathlib import Path
import streamlit as st
from detect import detect, opt
from PIL import Image
from util import get_all_weights
# Streamlit "magic": when this script is run with `streamlit run` (see the
# module docstring), a bare string literal on its own line like the one below
# is written to the page as Markdown. Here that is the page title "CycleGAN"
# followed by a feature summary ("features: upload local files, choose a
# conversion style"). Keep it as a bare expression; it is page content.
"""
# CycleGAN
功能:上传本地文件、选择转换风格
"""
def load_css(css_path="./util/streamlit/css.css"):
    """
    Inject a local CSS file into the Streamlit page.

    The stylesheet is read and emitted as a ``<style>`` element through
    ``st.markdown(..., unsafe_allow_html=True)``. If ``css_path`` does not
    point to a regular file, nothing is injected.

    :param css_path: path to the CSS file.
    """
    path = Path(css_path)
    # is_file() (rather than exists()) also skips directories, which open()
    # would otherwise reject with IsADirectoryError.
    if not path.is_file():
        return
    # Decode explicitly as UTF-8 instead of relying on the platform-dependent
    # default encoding used by a bare open().
    css = path.read_text(encoding="utf-8")
    # Insert the CSS file contents into the page HTML.
    st.markdown(
        f"""<style>{css}</style>""",
        unsafe_allow_html=True,
    )
def load_img_file(file):
    """
    Decode an uploaded image, preview it on the page, and return it.

    The whole content of ``file`` is read into memory and opened with PIL.
    The resulting image is displayed at column width before being returned.

    :param file: file-like object exposing ``read()`` (the callers in this
        script pass Streamlit uploads).
    :return: the image object produced by ``PIL.Image.open``.
    """
    raw_bytes = file.read()
    buffer = BytesIO(raw_bytes)
    picture = Image.open(buffer)
    # Preview the original image before any conversion.
    st.image(picture, use_column_width=True)
    return picture
def set_style_options(label: str, frame=st):
    """
    Render a style selectbox and return the chosen style.

    The choices are the entries returned by ``get_all_weights()``, preceded by
    ``None`` so that no style is selected by default.

    :param label: label shown with the selectbox. When no ``key`` is given,
        Streamlit derives the widget identity from its parameters, so widgets
        on the same page should use different labels.
    :param frame: object whose ``selectbox`` method renders the widget
        (defaults to the ``streamlit`` module itself).
    :return: the selected entry, or ``None`` while no style has been chosen.
    """
    # Unpacking accepts any iterable (list, tuple, generator). The previous
    # ``[None] + style_options`` only worked when a list was returned.
    options = [None, *get_all_weights()]
    return frame.selectbox(label=label, options=options)
# Stylesheet injection is currently switched off. Call load_css() here to
# enable it.
tab_mul2mul, tab_mul2one, tab_set = st.tabs(
    [
        "多图多风格转换",  # several images, a separate style per image
        "多图同风格转换",  # several images, one shared style
        "参数",  # inference parameters
    ]
)
with tab_mul2mul:
    # One row per uploaded image. The left side shows the preview and that
    # image's own style picker. The right side shows the converted image once
    # a style has been chosen.
    uploaded_files = st.file_uploader(
        label="选择本地图片",  # "choose local images"
        accept_multiple_files=True,
        key=1,
    )
    for idx, uploaded_file in enumerate(uploaded_files or []):
        colL, colR = st.columns(2)
        with colL:
            img = load_img_file(uploaded_file)
            # NOTE(review): the label is str(uploaded_file), i.e. the object's
            # repr rather than its file name. It presumably also keeps the
            # per-file selectboxes distinct; confirm before changing it.
            style = set_style_options(label=str(uploaded_file), frame=st)
        with colR:
            if not style:
                continue
            fake_img = detect(img=img, style=style)
            st.image(fake_img, caption="", use_column_width=True)
with tab_set:
    # Reserve the top of the tab for the option listing, but fill it only after
    # the controls below have been read. This way the listing shows the value
    # chosen in this run rather than the one assigned in the previous run,
    # while the on-screen order (listing, slider, radio) stays the same.
    options_box = st.container()
    # NOTE(review): confidence_threshold is not read anywhere in this file.
    confidence_threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.01)
    # NOTE(review): `opt` is imported from the `detect` module, and this
    # assignment mutates that shared object. It presumably affects every
    # session served by this process; confirm. If detect() reads opt at call
    # time, calls made earlier in this script run (the first tab) still saw
    # the previous value.
    opt.no_dropout = st.radio("no_dropout", [True, False])
    with options_box:
        # Read-only view of every option, sorted by name.
        for k, v in sorted(vars(opt).items()):
            st.text_input(label=k, value=v, disabled=True)
with tab_mul2one:
    # All uploaded images share one style. Previews go on the left. The style
    # picker, the trigger button and the converted images go on the right.
    uploaded_files = st.file_uploader(
        label="选择本地图片",  # "choose local images"
        accept_multiple_files=True,
        key=2,
    )
    if uploaded_files:
        colL, colR = st.columns(2)
        with colL:
            imgs = [load_img_file(upload) for upload in uploaded_files]
        with colR:
            style = set_style_options(label="选择风格", frame=st)
            # The button is only rendered once a style is chosen. Conversion
            # runs on the script run triggered by clicking it.
            if style and st.button("♻️风格转换", use_container_width=True):
                for img in imgs:
                    fake_img = detect(img, style)
                    st.image(fake_img, caption="", use_column_width=True)