Spaces:
Sleeping
Sleeping
import gradio as gr | |
import legacy | |
import dnnlib | |
import numpy as np | |
import torch | |
import find_direction | |
import generator | |
import psp_wrapper | |
# Pretrained checkpoints expected under ./pretrained.
psp_encoder_path = "./pretrained/e4e_ffhq_encode.pt"
landmarks_path = "./pretrained/shape_predictor_68_face_landmarks.dat"
G_ffhq_path = "./pretrained/ffhq.pkl"
G_metfaces_path = "./pretrained/metfaces.pkl"

# Image -> latent encoder built from the e4e checkpoint and the 68-point
# face-landmark model file.
e4e_embedder = psp_wrapper.psp_encoder(psp_encoder_path, landmarks_path)

# Prefer the first CUDA GPU; fall back to the CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_g_ema(pkl_path):
    """Return the ``G_ema`` network stored in ``pkl_path``, moved to ``device``."""
    with dnnlib.util.open_url(pkl_path) as f:
        return legacy.load_network_pkl(f)['G_ema'].to(device)


G_ffhq = _load_g_ema(G_ffhq_path)
G_metfaces = _load_g_ema(G_metfaces_path)

# Generator lookup keyed by the names offered in the "StyleGAN Type" radio.
G_dict = {"FFHQ": G_ffhq, "MetFaces": G_metfaces}
# Markdown header shown at the top of the page.
DESCRIPTION = '''# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> Multi-Channel Based Fast Text-Guided Image Generation and Manipulation
'''
# HTML credit line for the bottom of the page.
FOOTER = 'This space is built by <a href = "https://github.com/catlab-team">Catlab Team</a>.'

# Registry of computed directions: display name -> {"direction", "stylegan_type"}.
# NOTE: module-level state, so every visitor of this app process shares it.
direction_map = {}
# Display names in insertion order; used as the choices of the directions radio.
direction_list = []
def add_direction(prompt, stylegan_type, id_loss_w):
    """Find a text-guided manipulation direction and register it.

    The direction is computed with ``find_direction.find_direction`` on the
    generator selected by ``stylegan_type`` and stored in ``direction_map``
    (and ``direction_list``) under a display name built from the prompt, the
    generator type and the identity-loss weight. A name that is already
    registered is not recomputed.

    Args:
        prompt: Target text prompt from the textbox. ``None`` or a blank
            string is ignored.
        stylegan_type: Key of ``G_dict`` ("FFHQ" or "MetFaces").
        id_loss_w: Identity-loss weight from the UI slider.

    Returns:
        A ``gr.Radio.update`` listing every registered direction, with no
        selection.
    """
    # Validate before building the name: concatenating a None prompt raised
    # TypeError, and an empty textbox ("") launched a search for "".
    if prompt is not None and prompt.strip():
        new_dir_name = f"{prompt} {stylegan_type} w_id_loss{id_loss_w}"
        # direction_map and direction_list are updated together, so the dict
        # gives the same answer as the list with an O(1) lookup.
        if new_dir_name not in direction_map:
            print("adding direction with id:", new_dir_name)
            # NOTE(review): id_loss_w only appears in the name and the log
            # below; it is not passed to find_direction, so entries that differ
            # only in the weight are computed the same way. Confirm whether
            # find_direction should receive it.
            direction = find_direction.find_direction(G_dict[stylegan_type], prompt)
            print(f"new direction calculated with {stylegan_type} and id loss weight = {id_loss_w}")
            direction_list.append(new_dir_name)
            direction_map[new_dir_name] = {"direction": direction, "stylegan_type": stylegan_type}
    return gr.Radio.update(choices=direction_list, value=None, visible=True)
def generate_output_image(image_path, direction_id, change_power):
    """Apply a registered direction to an uploaded image.

    The image is embedded with ``e4e_embedder.get_w``, converted to style
    codes of the generator the direction was found on, and re-synthesized by
    ``generator.generate_from_style`` with the direction applied at strength
    ``change_power``.

    Args:
        image_path: File path of the uploaded image (the input component uses
            ``type="filepath"``); ``None`` when nothing was uploaded.
        direction_id: Display name of an entry in ``direction_map``; ``None``
            when no direction is selected.
        change_power: Manipulation strength from the UI slider.

    Returns:
        Whatever ``generator.generate_from_style`` returns. The output
        component is declared ``type="pil"``, so presumably a PIL image.

    Raises:
        gr.Error: If no known direction is selected or no image was uploaded,
            so the UI shows a readable message instead of a KeyError.
    """
    if direction_id not in direction_map:
        raise gr.Error("Please select a direction from the list (see Step 1).")
    if image_path is None:
        raise gr.Error("Please upload an input image.")

    entry = direction_map[direction_id]
    # Use the generator the direction was computed on. Previously FFHQ was
    # always used, so MetFaces directions were applied in FFHQ's style space
    # and the stored "stylegan_type" was never read.
    # NOTE(review): the encoder checkpoint is e4e_ffhq_encode.pt, so this
    # assumes its latents are usable with the MetFaces generator — confirm.
    G = G_dict[entry["stylegan_type"]]
    w = e4e_embedder.get_w(image_path)
    s = generator.w_to_s(GIn=G, wsIn=w)
    return generator.generate_from_style(
        GIn=G,
        styles=s,
        styles_direction=entry["direction"],
        change_power=change_power,
        outdir='.',
    )
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Step 1: compute a global direction for a text prompt.
    with gr.Box():
        # Heading and instructions on separate lines; collapsed onto one line,
        # the whole text rendered as a single "##" heading.
        gr.Markdown(
            "## Step 1 (Finding a global manipulation direction)\n"
            "- Please enter the target **text prompt** and **identity loss weight** "
            "to find global manipulation direction:\n"
            "- Hit the **Find Direction** button."
        )
        with gr.Row():
            with gr.Column():
                style_gan_type = gr.Radio(
                    ["FFHQ", "MetFaces"], value="FFHQ", label="StyleGAN Type", interactive=True
                )
            with gr.Column():
                identity_loss_weight = gr.Slider(
                    0.1, 10, value=0.5, step=0.1, label="Identity Loss Weight", interactive=True
                )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Textbox(
                        label="Enter your prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                    ).style(container=False)
                    find_direction_btn = gr.Button("Find Direction").style(full_width=False)

    # Step 2: apply a selected direction to an uploaded image.
    with gr.Box():
        gr.Markdown(
            "## Step 2 (Manipulation)\n"
            "- Please upload an image for manipulation:\n"
            "- You can also select the **previous directions** and determine the "
            "**manipulation strength**.\n"
            "- Hit the **Generate** button."
        )
        with gr.Row():
            direction_radio = gr.Radio(direction_list, label="List of Directions")
        with gr.Row():
            manipulation_strength = gr.Slider(
                0.1, 100, value=25, step=0.1, label="Manipulation Strength", interactive=True
            )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label="Input Image", type="filepath")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
            with gr.Column():
                with gr.Row():
                    generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)

    # FOOTER is defined at module level; render it at the bottom of the page.
    gr.Markdown(FOOTER)

    # Event wiring must happen inside the Blocks context.
    find_direction_btn.click(
        add_direction,
        inputs=[text, style_gan_type, identity_loss_weight],
        outputs=direction_radio,
    )
    generate_btn.click(
        generate_output_image,
        inputs=[input_image, direction_radio, manipulation_strength],
        outputs=generated_image,
    )

# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks up `demo`) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)