# stylemc-demo / app.py
# Author: adirik
# Commit b5e8b97: "update app" (file size 7.95 kB)
import gradio as gr
import legacy
import dnnlib
import numpy as np
import torch
from find_direction import find_direction
# Prefer the first CUDA GPU when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the network stored under the 'G_ema' key of the pretrained FFHQ pickle
# (presumably the EMA-averaged StyleGAN generator -- confirm) and move it to
# the selected device. This happens once, when the module is imported.
with dnnlib.util.open_url("./pretrained/ffhq.pkl") as f:
    G = legacy.load_network_pkl(f)['G_ema']
    G = G.to(device)
# Markdown heading rendered at the top of the demo page (links to the repo).
DESCRIPTION = (
    '# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> '
    'Multi-Channel Based Fast Text-Guided Image Generation and Manipulation\n'
)
# HTML credit line rendered at the bottom of the demo page.
FOOTER = (
    'This space is built by '
    '<a href = "https://github.com/catlab-team">Catlab Team</a>.'
)
def main():
    """Build the StyleMC Gradio Blocks page and launch it.

    The page is laid out in two steps:

    1. Enter a text prompt and an identity loss weight, then press
       "Find Direction".
    2. Upload an input image, choose a manipulation strength, then press
       "Generate"; the result is shown in the "Generated Image" panel.

    NOTE(review): neither ``find_direction_button`` nor ``generate_button``
    has an event handler attached yet, so pressing them does nothing.
    ``find_direction`` is imported at module level but is not used here;
    wire it to ``find_direction_button.click`` once its signature is
    confirmed.
    """
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        # Step 1: inputs for computing a text-driven manipulation direction.
        with gr.Box():
            gr.Markdown('''## Step 1 (Finding a global manipulation direction)
- Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction:
- Hit the **Find Direction** button.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        text = gr.Textbox(
                            label="Enter your prompt",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter your prompt",
                        ).style(
                            container=False,
                        )
                    identity_loss_weight = gr.Slider(0.1,
                                                     10,
                                                     value=0.5,
                                                     step=0.1,
                                                     label='Identity Loss Weight',
                                                     interactive=True)
                    find_direction_button = gr.Button(
                        "Find Direction").style(full_width=False)

        # Step 2: apply a manipulation to an uploaded image.
        with gr.Box():
            gr.Markdown('''## Step 2 (Manipulation)
- Please upload an image for manipulation:
- You can also select the **previous directions** and determine the **manipulation strength**.
- Hit the **Generate** button.
''')
            with gr.Row():
                # Distinct name: this used to rebind `identity_loss_weight`,
                # which made the Step 1 slider unreachable for event wiring.
                manipulation_strength = gr.Slider(0.1,
                                                  100,
                                                  value=50,
                                                  step=0.1,
                                                  label='Manipulation Strength',
                                                  interactive=True)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='Input Image',
                                               type='filepath')
                    with gr.Row():
                        generate_button = gr.Button('Generate')
                with gr.Column():
                    with gr.Row():
                        # Output-only panel: the user cannot edit it.
                        generated_image = gr.Image(label='Generated Image',
                                                   type='numpy',
                                                   interactive=False)

        gr.Markdown(FOOTER)

        # TODO(review): attach callbacks, e.g.
        #   find_direction_button.click(fn=..., inputs=[text, identity_loss_weight], outputs=...)
        #   generate_button.click(fn=..., inputs=[input_image, manipulation_strength],
        #                         outputs=generated_image)

    demo.launch()
if __name__ == "__main__":
    # Build and launch the demo only when run as a script, not on import.
    main()