"""Streamlit front end that feeds uploaded images to a SimSwap model."""

from collections import namedtuple
from io import BytesIO

import numpy as np
import streamlit as st
from PIL import Image

from src.simswap import SimSwap


def _load_rgb_image(file):
    """Decode an uploaded image into an ``(H, W, 3)`` NumPy array.

    The image is converted to RGB by Pillow before it becomes an array, so
    grayscale and palette images (which decode to 2-D arrays) no longer fail
    on channel slicing, and an alpha channel is dropped.

    Args:
        file: A path or binary file-like object accepted by ``Image.open``.

    Returns:
        The pixel data as a NumPy array with three channels.
    """
    return np.array(Image.open(file).convert("RGB"))


def _show_inputs(panels):
    """Display each ``(title, image)`` pair in its own column, side by side."""
    # st.columns needs at least one column, even when nothing was uploaded.
    columns = st.columns(max(len(panels), 1))
    for column, (title, image) in zip(columns, panels):
        with column:
            st.header(title)
            st.image(image)


def _show_output(output):
    """Display the model output and offer it as a JPEG download."""
    with st.container():
        st.header("SimSwap output")
        st.image(output)

        output_to_download = Image.fromarray(output.astype("uint8"), "RGB")
        buf = BytesIO()
        output_to_download.save(buf, format="JPEG")

        st.download_button(
            label="Download",
            data=buf.getvalue(),
            file_name="output.jpg",
            mime="image/jpeg",
        )


def run(model):
    """Render the UI and run ``model`` once the required images are uploaded.

    The sidebar collects up to three images (ID, attribute, and an optional
    "specific person" image) plus the tuning parameters. The uploaded images
    are shown in columns. When both the ID and the attribute image are
    present, the settings are pushed into ``model``, the model is called on
    the attribute image, and a non-``None`` result is displayed with a
    download button.

    Args:
        model: Object exposing the ``set_*`` methods and attributes used
            below and callable on an image array (the script passes the
            instance returned by ``load_model``).
    """
    id_image = None
    attr_image = None
    specific_image = None
    output = None

    with st.sidebar:
        uploaded_file = st.file_uploader("Select an ID image")
        if uploaded_file is not None:
            id_image = _load_rgb_image(uploaded_file)

        uploaded_file = st.file_uploader("Select an Attribute image")
        if uploaded_file is not None:
            attr_image = _load_rgb_image(uploaded_file)

        uploaded_file = st.file_uploader(
            "Select a specific person image (Optional)"
        )
        if uploaded_file is not None:
            specific_image = _load_rgb_image(uploaded_file)

        face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))

        enhance_output = st.radio("Enhance output:", ("yes", "no"))

        smooth_mask_iter = st.slider(
            label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
        )

        smooth_mask_kernel_size = st.slider(
            label="smooth_mask_kernel_size",
            min_value=1,
            max_value=61,
            step=2,
            value=17,
        )

        smooth_mask_threshold = st.slider(
            label="smooth_mask_threshold",
            min_value=0.01,
            max_value=1.0,
            step=0.01,
            value=0.9,
        )

        specific_latent_match_threshold = st.slider(
            label="specific_latent_match_threshold",
            min_value=0.0,
            max_value=10.0,
            value=0.05,
        )

    # Only the images that were actually uploaded get a column.
    candidates = (
        ("ID image", id_image),
        ("Attribute image", attr_image),
        ("Specific image", specific_image),
    )
    _show_inputs([(title, img) for title, img in candidates if img is not None])

    if id_image is not None and attr_image is not None:
        model.set_face_alignment_type(face_alignment_type)
        model.set_smooth_mask_iter(smooth_mask_iter)
        model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
        model.set_smooth_mask_threshold(smooth_mask_threshold)
        model.set_specific_latent_match_threshold(specific_latent_match_threshold)
        model.enhance_output = enhance_output == "yes"

        # NOTE(review): the latents are cleared before the images are
        # assigned, presumably so the model recomputes them from the new
        # uploads -- confirm against SimSwap.
        model.specific_latent = None
        model.specific_id_image = specific_image

        model.id_latent = None
        model.id_image = id_image

        output = model(attr_image)

    if output is not None:
        _show_output(output)


def _cache_model(func):
    """Wrap ``func`` with Streamlit's resource cache.

    ``st.cache`` is deprecated in favour of ``st.cache_resource``; the new
    API is used when the installed Streamlit provides it, and the original
    ``st.cache(allow_output_mutation=True)`` call is kept as a fallback for
    older releases.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_model
def load_model(config):
    """Construct a ``SimSwap`` from ``config`` with no images preset.

    The result is cached by Streamlit, so the same instance is returned for
    an equal ``config`` instead of being rebuilt on every call.
    """
    return SimSwap(
        config=config,
        id_image=None,
        specific_image=None,
    )


# TODO: remove it and use config files from 'configs'
Config = namedtuple(
    "Config",
    [
        "face_detector_weights",
        "face_id_weights",
        "parsing_model_weights",
        "simswap_weights",
        "gfpgan_weights",
        "blend_module_weights",
        "device",
        "crop_size",
        "checkpoint_type",
        "face_alignment_type",
        "smooth_mask_iter",
        "smooth_mask_kernel_size",
        "smooth_mask_threshold",
        "face_detector_threshold",
        "specific_latent_match_threshold",
        "enhance_output",
    ],
)


if __name__ == "__main__":
    config = Config(
        face_detector_weights="weights/scrfd_10g_bnkps.onnx",
        face_id_weights="weights/arcface_net.jit",
        parsing_model_weights="weights/79999_iter.pth",
        simswap_weights="weights/latest_net_G.pth",
        gfpgan_weights="weights/GFPGANv1.4_ema.pth",
        blend_module_weights="weights/blend.jit",
        device="cuda",
        crop_size=224,
        checkpoint_type="official_224",
        face_alignment_type="none",
        smooth_mask_iter=7,
        smooth_mask_kernel_size=17,
        smooth_mask_threshold=0.9,
        face_detector_threshold=0.6,
        specific_latent_match_threshold=0.05,
        enhance_output=True,
    )

    model = load_model(config)
    run(model)