simswap55 / app_web.py
LB5's picture
Upload 45 files
22b8701
raw
history blame contribute delete
No virus
4.83 kB
import streamlit as st
from PIL import Image
from io import BytesIO
from collections import namedtuple
import numpy as np
from src.simswap import SimSwap
def run(model):
    """Render the SimSwap Streamlit page and run a face swap when possible.

    The sidebar collects an ID image, an attribute image, an optional
    "specific person" image and the swap/post-processing settings.
    Uploaded images are previewed side by side. Once both the ID and the
    attribute image are present, the settings are pushed onto ``model``,
    the swap is run on the attribute image, and the result is shown
    together with a JPEG download button.

    Args:
        model: The swap model (as returned by ``load_model``). It is mutated
            in place on every rerun: its settings, ``enhance_output``,
            latents and ``id_image`` / ``specific_id_image`` are overwritten.
    """

    def get_np_image(file):
        # Normalise to 3-channel RGB before converting to an array. Slicing
        # the raw array with ``[:, :, :3]`` fails on 2-D grayscale ("L")
        # images and returns palette indices for "P" images. For RGB/RGBA
        # uploads the result is unchanged (alpha is simply dropped).
        return np.array(Image.open(file).convert("RGB"))

    id_image = None
    attr_image = None
    specific_image = None

    with st.sidebar:
        uploaded_file = st.file_uploader("Select an ID image")
        if uploaded_file is not None:
            id_image = get_np_image(uploaded_file)

        uploaded_file = st.file_uploader("Select an Attribute image")
        if uploaded_file is not None:
            attr_image = get_np_image(uploaded_file)

        uploaded_file = st.file_uploader("Select a specific person image (Optional)")
        if uploaded_file is not None:
            specific_image = get_np_image(uploaded_file)

        face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))
        enhance_output = st.radio("Enhance output:", ("yes", "no"))
        smooth_mask_iter = st.slider(
            label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
        )
        smooth_mask_kernel_size = st.slider(
            label="smooth_mask_kernel_size", min_value=1, max_value=61, step=2, value=17
        )
        smooth_mask_threshold = st.slider(
            label="smooth_mask_threshold", min_value=0.01, max_value=1.0, step=0.01, value=0.9
        )
        specific_latent_match_threshold = st.slider(
            label="specific_latent_match_threshold",
            min_value=0.0,
            max_value=10.0,
            value=0.05,
        )

    # Preview every uploaded image in its own column. Streamlit needs at
    # least one column, so fall back to a single (empty) one.
    previews = [
        (title, image)
        for title, image in (
            ("ID image", id_image),
            ("Attribute image", attr_image),
            ("Specific image", specific_image),
        )
        if image is not None
    ]
    cols = st.columns(max(len(previews), 1))
    for col, (title, image) in zip(cols, previews):
        with col:
            st.header(title)
            st.image(image)

    # A swap needs both an identity source and a target (attribute) image.
    if id_image is None or attr_image is None:
        return

    model.set_face_alignment_type(face_alignment_type)
    model.set_smooth_mask_iter(smooth_mask_iter)
    model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
    model.set_smooth_mask_threshold(smooth_mask_threshold)
    model.set_specific_latent_match_threshold(specific_latent_match_threshold)
    model.enhance_output = enhance_output == "yes"
    # Latents are cleared before the matching image is assigned, presumably
    # so the model recomputes them from the new images.
    # NOTE(review): confirm against SimSwap; the assignment order is kept
    # as-is in case these are properties with side effects.
    model.specific_latent = None
    model.specific_id_image = specific_image
    model.id_latent = None
    model.id_image = id_image
    output = model(attr_image)

    if output is None:
        return

    with st.container():
        st.header("SimSwap output")
        st.image(output)
        # Clip before casting so out-of-range values saturate instead of
        # wrapping around modulo 256.
        output_to_download = Image.fromarray(
            np.clip(output, 0, 255).astype("uint8"), "RGB"
        )
        buf = BytesIO()
        output_to_download.save(buf, format="JPEG")
        st.download_button(
            label="Download",
            data=buf.getvalue(),
            file_name="output.jpg",
            mime="image/jpeg",
        )
def _resource_cache():
    """Return Streamlit's decorator for caching a shared, mutable resource.

    ``st.cache`` is deprecated; ``st.cache_resource`` is its replacement for
    objects such as ML models that must not be copied or re-hashed. Fall
    back to the legacy decorator only on Streamlit versions predating it.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def load_model(config):
    """Build the SimSwap model for ``config`` once and reuse it across reruns.

    Args:
        config: A ``Config`` namedtuple with weight paths, device and the
            default swap settings; it is also the cache key.

    Returns:
        A ``SimSwap`` instance created without an ID or specific image.
        Those are assigned later by the caller.

    NOTE(review): the cached instance is shared, and ``run()`` mutates its
    settings and images on every rerun. Concurrent sessions therefore share
    that state.
    """
    return SimSwap(
        config=config,
        id_image=None,
        specific_image=None,
    )
# TODO: remove it and use config files from 'configs'
# Immutable bundle of settings handed to ``SimSwap(config=...)`` by
# ``load_model``: model weight paths, device, crop/checkpoint selection and
# the default face-alignment / mask / enhancement parameters.
Config = namedtuple(
    "Config",
    [
        "face_detector_weights",
        "face_id_weights",
        "parsing_model_weights",
        "simswap_weights",
        "gfpgan_weights",
        "blend_module_weights",
        "device",
        "crop_size",
        "checkpoint_type",
        "face_alignment_type",
        "smooth_mask_iter",
        "smooth_mask_kernel_size",
        "smooth_mask_threshold",
        "face_detector_threshold",
        "specific_latent_match_threshold",
        "enhance_output",
    ],
)
if __name__ == "__main__":
    # Checkpoint files, resolved relative to the working directory.
    weight_paths = dict(
        face_detector_weights="weights/scrfd_10g_bnkps.onnx",
        face_id_weights="weights/arcface_net.jit",
        parsing_model_weights="weights/79999_iter.pth",
        simswap_weights="weights/latest_net_G.pth",
        gfpgan_weights="weights/GFPGANv1.4_ema.pth",
        blend_module_weights="weights/blend.jit",
    )
    # Runtime settings. run() re-applies alignment, mask, match-threshold
    # and enhancement values from the sidebar widgets on every rerun.
    settings = dict(
        device="cuda",
        crop_size=224,
        checkpoint_type="official_224",
        face_alignment_type="none",
        smooth_mask_iter=7,
        smooth_mask_kernel_size=17,
        smooth_mask_threshold=0.9,
        face_detector_threshold=0.6,
        specific_latent_match_threshold=0.05,
        enhance_output=True,
    )
    config = Config(**weight_paths, **settings)
    model = load_model(config)
    run(model)