#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr

from model import Model

DESCRIPTION = '''# [HairCLIP](https://github.com/wty-ustc/HairCLIP)

<center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
'''


def load_hairstyle_list() -> list[str]:
    # Read the hairstyle descriptions bundled with the HairCLIP repo. Each
    # line ends with the 10-character suffix ' hairstyle', which is stripped
    # so only the style name (e.g. 'afro') remains for the dropdown.
    with open('HairCLIP/mapper/hairstyle_list.txt') as f:
        lines = [line.strip() for line in f]
    return [line[:-10] for line in lines]


def set_example_image(example: list) -> dict:
    # Copy a clicked example into the input image component. Not wired up
    # in this demo; gr.Examples below handles example clicks automatically.
    return gr.Image.update(value=example[0])


def update_step2_components(choice: str) -> tuple[dict, dict]:
    # Show only the controls relevant to the selected editing type: the
    # hairstyle dropdown for 'hairstyle'/'both', the color textbox for
    # 'color'/'both'.
    return (
        gr.Dropdown.update(visible=choice in ['hairstyle', 'both']),
        gr.Textbox.update(visible=choice in ['color', 'both']),
    )


# Model wraps face detection/alignment, latent reconstruction, and
# text-driven hair editing (see the event handlers below).
model = Model()

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Box():
        gr.Markdown('## Step 1')
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image',
                                           type='filepath')
                with gr.Row():
                    preprocess_button = gr.Button('Preprocess')
            with gr.Column():
                aligned_face = gr.Image(label='Aligned Face',
                                        type='pil',
                                        interactive=False)
            with gr.Column():
                reconstructed_face = gr.Image(label='Reconstructed Face',
                                              type='numpy')
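                # Hidden state carrying the inverted latent code from
                # Step 1 into the Step 2 handlers below.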
                latent = gr.Variable()

        with gr.Row():
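            # Sample images bundled with the demo, selectable as inputs.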
            paths = sorted(pathlib.Path('images').glob('*.jpg'))
            gr.Examples(examples=[[path.as_posix()] for path in paths],
                        inputs=input_image)

    with gr.Box():
        gr.Markdown('## Step 2')
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    editing_type = gr.Radio(
                        label='Editing Type',
                        choices=['hairstyle', 'color', 'both'],
                        value='both',
                        type='value')
                with gr.Row():
                    hairstyles = load_hairstyle_list()
                    hairstyle_index = gr.Dropdown(label='Hairstyle',
                                                  choices=hairstyles,
                                                  value='afro',
                                                  type='index')
                with gr.Row():
                    color_description = gr.Textbox(label='Color', value='red')
                with gr.Row():
                    run_button = gr.Button('Run')

            with gr.Column():
                result = gr.Image(label='Result')

    # Step 1: detect and align the face in the uploaded image.
    preprocess_button.click(fn=model.detect_and_align_face,
                            inputs=input_image,
                            outputs=aligned_face)
    # Invert the aligned face into latent space; the latent code is cached
    # for Step 2 alongside the reconstructed preview.
    aligned_face.change(fn=model.reconstruct_face,
                        inputs=aligned_face,
                        outputs=[reconstructed_face, latent])
    # Toggle the hairstyle/color controls to match the selected editing type.
    editing_type.change(fn=update_step2_components,
                        inputs=editing_type,
                        outputs=[hairstyle_index, color_description])
    # Step 2: apply the requested hairstyle and/or color edit to the latent.
    run_button.click(fn=model.generate,
                     inputs=[
                         editing_type,
                         hairstyle_index,
                         color_description,
                         latent,
                     ],
                     outputs=result)

# Queue requests (up to 10 waiting) so long-running inference calls
# don't overlap, then launch the demo.
demo.queue(max_size=10).launch()