Spaces:
Configuration error
| import streamlit as st | |
| from PIL import Image | |
| from io import BytesIO | |
| from collections import namedtuple | |
| import numpy as np | |
| from src.simswap import SimSwap | |
def run(model):
    """Render the SimSwap Streamlit UI and run a face swap once inputs allow.

    The sidebar collects an ID image, an attribute image, an optional
    "specific person" image and the tuning parameters. Every successfully
    decoded upload is shown side by side. When both the ID and the attribute
    image are present, the parameters are pushed into ``model`` and
    ``model(attr_image)`` is called. A non-``None`` result is displayed with a
    JPEG download button.

    Args:
        model: SimSwap-like callable exposing the ``set_*`` methods and the
            ``enhance_output`` / ``specific_latent`` / ``specific_id_image`` /
            ``id_latent`` / ``id_image`` attributes assigned below.
    """
    id_image = None
    attr_image = None
    specific_image = None
    output = None

    def get_np_image(file):
        """Decode an uploaded file into a 3-channel RGB numpy array.

        ``convert("RGB")`` normalises grayscale, palette, LA, RGBA and CMYK
        inputs, so 2-D and 2-/4-channel images no longer break channel slicing.
        Returns ``None`` (after showing an error) when PIL cannot read the file.
        """
        try:
            with Image.open(file) as img:
                return np.array(img.convert("RGB"))
        except OSError:
            # PIL.UnidentifiedImageError and truncated-image errors subclass OSError.
            st.error(f"Could not read '{getattr(file, 'name', file)}' as an image.")
            return None

    with st.sidebar:
        uploaded_file = st.file_uploader("Select an ID image")
        if uploaded_file is not None:
            id_image = get_np_image(uploaded_file)

        uploaded_file = st.file_uploader("Select an Attribute image")
        if uploaded_file is not None:
            attr_image = get_np_image(uploaded_file)

        uploaded_file = st.file_uploader("Select a specific person image (Optional)")
        if uploaded_file is not None:
            specific_image = get_np_image(uploaded_file)

        face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))
        enhance_output = st.radio("Enhance output:", ("yes", "no"))

        smooth_mask_iter = st.slider(
            label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
        )
        # step=2 keeps the kernel size odd (starting from 1).
        smooth_mask_kernel_size = st.slider(
            label="smooth_mask_kernel_size", min_value=1, max_value=61, step=2, value=17
        )
        smooth_mask_threshold = st.slider(
            label="smooth_mask_threshold",
            min_value=0.01,
            max_value=1.0,
            step=0.01,
            value=0.9,
        )
        specific_latent_match_threshold = st.slider(
            label="specific_latent_match_threshold",
            min_value=0.0,
            max_value=10.0,
            value=0.05,
        )

    # Show every available input image in its own column (at least one column).
    shown_inputs = [
        (title, image)
        for title, image in (
            ("ID image", id_image),
            ("Attribute image", attr_image),
            ("Specific image", specific_image),
        )
        if image is not None
    ]
    cols = st.columns(max(len(shown_inputs), 1))
    for col, (title, image) in zip(cols, shown_inputs):
        with col:
            st.header(title)
            st.image(image)

    if id_image is not None and attr_image is not None:
        model.set_face_alignment_type(face_alignment_type)
        model.set_smooth_mask_iter(smooth_mask_iter)
        model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
        model.set_smooth_mask_threshold(smooth_mask_threshold)
        model.set_specific_latent_match_threshold(specific_latent_match_threshold)
        model.enhance_output = enhance_output == "yes"
        # NOTE(review): latents are reset to None, presumably so the model
        # re-derives them from the images assigned here — confirm in SimSwap.
        model.specific_latent = None
        model.specific_id_image = specific_image
        model.id_latent = None
        model.id_image = id_image
        output = model(attr_image)

    if output is not None:
        with st.container():
            st.header("SimSwap output")
            st.image(output)

            # Assumes output is an (H, W, 3) array with values in 0..255 — confirm.
            output_to_download = Image.fromarray(output.astype("uint8"), "RGB")
            buf = BytesIO()
            output_to_download.save(buf, format="JPEG")
            st.download_button(
                label="Download",
                data=buf.getvalue(),
                file_name="output.jpg",
                mime="image/jpeg",
            )
@st.cache_resource
def load_model(config):
    """Build the SimSwap pipeline once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction. Without
    caching, every slider move would reload all of the network weights listed
    in ``config``. ``st.cache_resource`` (Streamlit >= 1.18) hashes ``config``,
    a namedtuple of plain values, and returns the same object for repeated
    calls with an equal config.

    NOTE(review): the cached model is shared by all browser sessions, and
    ``run`` mutates its attributes before each call. Concurrent users may
    therefore interfere with each other's settings; add a lock if that matters.

    Args:
        config: Settings forwarded to ``SimSwap`` (see ``Config``).

    Returns:
        A ``SimSwap`` instance created with no ID image and no specific image.
    """
    return SimSwap(
        config=config,
        id_image=None,
        specific_image=None,
    )
# TODO: remove it and use config files from 'configs'
# Flat settings record handed to SimSwap via load_model(); the field order
# matters for positional construction.
Config = namedtuple(
    "Config",
    [
        # Weight files for each network in the pipeline.
        "face_detector_weights",
        "face_id_weights",
        "parsing_model_weights",
        "simswap_weights",
        "gfpgan_weights",
        "blend_module_weights",
        # Runtime / model selection.
        "device",
        "crop_size",
        "checkpoint_type",
        # Inference tuning knobs (also adjustable from the UI sliders/radios).
        "face_alignment_type",
        "smooth_mask_iter",
        "smooth_mask_kernel_size",
        "smooth_mask_threshold",
        "face_detector_threshold",
        "specific_latent_match_threshold",
        "enhance_output",
    ],
)
if __name__ == "__main__":
    # Hard-coded demo settings; weight paths are relative to the working directory.
    settings = dict(
        face_detector_weights="weights/scrfd_10g_bnkps.onnx",
        face_id_weights="weights/arcface_net.jit",
        parsing_model_weights="weights/79999_iter.pth",
        simswap_weights="weights/latest_net_G.pth",
        gfpgan_weights="weights/GFPGANv1.4_ema.pth",
        blend_module_weights="weights/blend.jit",
        device="cuda",
        crop_size=224,
        checkpoint_type="official_224",
        face_alignment_type="none",
        smooth_mask_iter=7,
        smooth_mask_kernel_size=17,
        smooth_mask_threshold=0.9,
        face_detector_threshold=0.6,
        specific_latent_match_threshold=0.05,
        enhance_output=True,
    )
    config = Config(**settings)
    model = load_model(config)
    run(model)