"""Gradio demo for StyleMC: text-guided StyleGAN image generation and manipulation.

The UI has two steps: (1) enter a text prompt and an identity-loss weight to
find a global manipulation direction, and (2) upload an image and choose a
manipulation strength to generate a manipulated image.
"""
import gradio as gr
import legacy
import dnnlib
import numpy as np
import torch

from find_direction import find_direction

# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pretrained FFHQ network once at import time and move its 'G_ema'
# entry to `device`.
# NOTE(review): G is not referenced anywhere in this file; presumably it is
# meant to be used by find_direction / the generation step once the buttons
# below are wired up — confirm.
with dnnlib.util.open_url("./pretrained/ffhq.pkl") as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

DESCRIPTION = '''# StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation
'''
FOOTER = 'This space is built by Catlab Team.'


def main(**launch_kwargs):
    """Build the two-step StyleMC Gradio interface and launch it.

    Args:
        **launch_kwargs: Forwarded unchanged to ``demo.launch`` (for example
            ``server_port=7860`` or ``share=True``). With no arguments the
            app launches with Gradio's defaults.
    """
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        # Step 1: text prompt + identity loss weight -> global direction.
        with gr.Box():
            # Markdown lines start at column 0 so the bullets are not rendered
            # as an indented code block.
            gr.Markdown('''## Step 1 (Finding a global manipulation direction)
- Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction:
- Hit the **Find Direction** button.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        text = gr.Textbox(
                            label="Enter your prompt",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter your prompt",
                        ).style(
                            container=False,
                        )
                    identity_loss_weight = gr.Slider(
                        0.1,
                        10,
                        value=0.5,
                        step=0.1,
                        label='Identity Loss Weight',
                        interactive=True,
                    )
                    find_direction_button = gr.Button("Find Direction").style(full_width=False)

        # Step 2: apply a direction to an uploaded image.
        with gr.Box():
            gr.Markdown('''## Step 2 (Manipulation)
- Please upload an image for manipulation:
- You can also select the **previous directions** and determine the **manipulation strength**.
- Hit the **Generate** button.
''')
            # TODO(review): the text above mentions selecting "previous
            # directions", but no component for that exists yet.
            with gr.Row():
                # Distinct name: previously this reused `identity_loss_weight`
                # and silently replaced the handle to the Step 1 slider.
                manipulation_strength = gr.Slider(
                    0.1,
                    100,
                    value=50,
                    step=0.1,
                    label='Manipulation Strength',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # type='filepath' delivers the upload as a file path string.
                        input_image = gr.Image(label='Input Image', type='filepath')
                    with gr.Row():
                        generate_button = gr.Button('Generate')
                with gr.Column():
                    with gr.Row():
                        generated_image = gr.Image(label='Generated Image', type='numpy', interactive=False)

        gr.Markdown(FOOTER)

        # TODO(review): neither button has an event handler yet, so clicking
        # them does nothing. Presumably `find_direction_button` should call
        # `find_direction` with (text, identity_loss_weight) and
        # `generate_button` should fill `generated_image` from
        # (input_image, manipulation_strength). Confirm those signatures,
        # then wire them with `.click(fn=..., inputs=[...], outputs=...)`.

    demo.launch(**launch_kwargs)


if __name__ == '__main__':
    main()