import gradio as gr
from PIL import Image

import os
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.modules.stylegan2.model import StyledConv, ToRGB, EqualLinear, ResBlock, ConvLayer, PixelNorm

from utils.util import *
from utils.data_utils import Transforms
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from configs import parse_config
from utils.augmentation import ImagePathToImage
import clip
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from models.style_based_pix2pixII_model import CLIPFeats2Wplus


class Stylizer(nn.Module):
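    """Image stylizer: a convolutional encoder followed by a StyleGAN2-style
    decoder. Per-layer style codes come from a latent mapping network, or (in
    phase 3) from CLIP image/text features via a CLIPFeats2Wplus mapper."""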

    def __init__(self, ngf=64, phase=2, model_weights=None):
        super(Stylizer, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            ConvLayer(3, ngf, 3),          # 512
            ResBlock(ngf * 1, ngf * 1),    # 256
            ResBlock(ngf * 1, ngf * 2),    # 128
            ResBlock(ngf * 2, ngf * 4),    # 64
            ResBlock(ngf * 4, ngf * 8),    # 32
            ConvLayer(ngf * 8, ngf * 8, 3) # 32
        )

        # mapping network
        self.mapping_z = nn.Sequential(*([ PixelNorm() ] + [ EqualLinear(512, 512, activation='fused_lrelu', lr_mul=0.01) for _ in range(8) ]))

        # style-based decoder
        channels = {
            32 : ngf * 8,
            64 : ngf * 8,
            128: ngf * 4,
            256: ngf * 2,
            512: ngf * 1
        }
        self.decoder0 = StyledConv(channels[32], channels[32], 3, 512)
        self.to_rgb0 = ToRGB(channels[32], 512, upsample=False)
        for i in range(4):
            ichan = channels[2 ** (i + 5)]
            ochan = channels[2 ** (i + 6)]
            setattr(self, f'decoder{i + 1}a', StyledConv(ichan, ochan, 3, 512, upsample=True))
            setattr(self, f'decoder{i + 1}b', StyledConv(ochan, ochan, 3, 512))
            setattr(self, f'to_rgb{i + 1}', ToRGB(ochan, 512))
        self.n_latent = 10  # number of per-layer style codes expected by forward()

        # random style for testing
        self.test_z = torch.randn(1, 512)

        # load pretrained model weights

        if phase == 2:
            # load pretrained encoder and stylegan2 decoder
            self.load_state_dict(model_weights)
        if phase == 3:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load pretrained base model and freeze all params except the CLIP mapper
            self.load_state_dict(model_weights, strict=False)
            params = dict(self.named_parameters())
            for k in params.keys():
                if 'clip_mapper' in k:
                    print(f'{k} not frozen!')
                    continue
                params[k].requires_grad = False

    def get_styles(self, x, **kwargs):
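        # Produce one 512-d style code per decoder layer (n_latent in total):
        #   no kwargs        -> the fixed test latent `test_z`
        #   mixing=True      -> style mixing between two random latents
        #   z=<tensor>       -> map the given latent through mapping_z
        #   clip_feats=<...> -> map CLIP features through clip_mapper (phase 3)
        #   anything else    -> a freshly sampled random latent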
        if len(kwargs) == 0:
            return self.mapping_z(self.test_z.to(x.device).repeat(x.shape[0], 1)).repeat(self.n_latent, 1, 1)
        elif 'mixing' in kwargs and kwargs['mixing']:
            w0 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            w1 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            inject_index = random.randint(1, self.n_latent - 1)
            return torch.cat([
                w0.repeat(inject_index, 1, 1),
                w1.repeat(self.n_latent - inject_index, 1, 1)
            ])
        elif 'z' in kwargs:
            return self.mapping_z(kwargs['z']).repeat(self.n_latent, 1, 1)
        elif 'clip_feats' in kwargs:
            return self.clip_mapper(kwargs['clip_feats'])
        else:
            z = torch.randn(x.shape[0], 512, device=x.device)
            return self.mapping_z(z).repeat(self.n_latent, 1, 1)

    def forward(self, x, **kwargs):
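        # kwargs are forwarded to get_styles() to select how style codes are generated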
        # encode
        feat = self.encoder(x)

        # get style code
        styles = self.get_styles(x, **kwargs)

        # style-based generate
        feat = self.decoder0(feat, styles[0])
        out = self.to_rgb0(feat, styles[1])
        for i in range(4):
            feat = getattr(self, f'decoder{i + 1}a')(feat, styles[i * 2 + 1])
            feat = getattr(self, f'decoder{i + 1}b')(feat, styles[i * 2 + 2])
            out = getattr(self, f'to_rgb{i + 1}')(feat, styles[i * 2 + 3], out)

        return F.hardtanh(out)

def tensor2file(input_image):
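    """Convert a model output tensor in [-1, 1] (NCHW) to a PIL image.
    Numpy arrays are used as-is; any other input is returned unchanged."""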

    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image

    image_numpy = image_numpy.astype(np.uint8)
    if image_numpy.shape[2] > 3:  # drop any extra channels (e.g. alpha)
        image_numpy = image_numpy[:, :, :3]
    return Image.fromarray(image_numpy)


device = "cuda"
def generate_multi_model(input_img):
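    """Multi-model tab: stylize `input_img` with the phase-2 pretrained model,
    using a randomly sampled style code, and return the stylized PIL image."""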

    # parse config
    config = parse_config("./exp/sp2pII-phase2.yaml")


    # hard-code some parameters for test
    config['common']['phase'] = "test"
    config['dataset']['n_threads'] = 0   # test code only supports num_threads = 0
    config['dataset']['batch_size'] = 1    # test code only supports batch_size = 1
    config['dataset']['serial_batches'] = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config['dataset']['no_flip'] = True    # no flip; comment this line if results on flipped images are needed.

    # override data augmentation
    config['dataset']['load_size'] = config['testing']['load_size']
    config['dataset']['crop_size'] = config['testing']['crop_size']
    config['dataset']['preprocess'] = config['testing']['preprocess']

    config['training']['pretrained_model'] = "./pretrained_models/phase2_pretrain_90000.pth"
    
    # add testing path
    config['testing']['test_img'] = input_img
    config['testing']['test_video'] = None
    config['testing']['test_folder'] = None

    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)

    model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')

    # init netG
    model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
    model.eval()
    model.requires_grad_(False)

    for data in dataloader:
        real_A = data['test_A'].to(device)
        with torch.no_grad():
            fake_B = model(real_A, mixing=False)
        output_img = tensor2file(fake_B)  # get image results
        return output_img
    

def generate_one_shot(src_img, img_prompt):
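    """One-shot tab: `img_prompt` is a radio label like 'ref01'; its last two
    characters pick both the fine-tuned checkpoint folder and the reference
    image, which is CLIP-encoded and mapped to style codes by the phase-3 model."""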

    # init model
    state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)

    # image transform for stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])

    # get clip features
    with torch.no_grad():
        img = img_preprocess(Image.open(f"./example/reference/{img_prompt[-2:]}.png")).unsqueeze(0).to(device)
        clip_feats = clip_model.encode_image(img)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)


    # load image & to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it !
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)

    output_img = tensor2file(res)  # get image results
    return output_img


def generate_zero_shot(src_img, txt_prompt):
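    """Zero-shot tab: load the checkpoint fine-tuned for `txt_prompt` (spaces
    replaced by underscores), CLIP-encode the prompt, and stylize `src_img`."""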
    # init model
    state_dict = torch.load(f"./checkpoints/{txt_prompt.replace(' ', '_')}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)

    # image transform for stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])

    # get clip features
    with torch.no_grad():
        text = clip.tokenize(txt_prompt).to(device)
        clip_feats = clip_model.encode_text(text)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)


    # load image & to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it !
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)

    output_img = tensor2file(res)  # get image results
    return output_img


with gr.Blocks() as demo:
    # page title
    gr.Markdown("# MMFS")

    # one tab per stylization mode
    with gr.Tabs():

        with gr.TabItem("Multi-Model"):
            multi_input_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=multi_input_img)
            multi_model_button = gr.Button("Random Stylize")
            
            multi_output_img = gr.Image(label="Output Image", height=400)


        with gr.TabItem("One-Shot"):
            one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=one_shot_src_img)
            with gr.Row():
                gr.Image(shape=(100, 100), value=Image.open("example/reference/01.png"), type='pil', label="ref01")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/02.png"), type='pil', label="ref02")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/03.png"), type='pil', label="ref03")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/04.png"), type='pil', label="ref04")
            
            one_shot_ref_img = gr.Radio(['ref01', 'ref02', 'ref03', 'ref04'], value="ref01", label="Select a reference style image")

            one_shot_test_button = gr.Button("Stylize Image")

            one_shot_output_img = gr.Image(label="Output Image", height=400)

        
        with gr.TabItem("Zero-Shot"):
            zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=zero_shot_src_img)
            zero_shot_ref_prompt = gr.Dropdown(
                label="Txt Prompt",
                info="Select a reference style prompt",
                choices=[
                    "pop art",
                    "watercolor painting",
                ],
                max_choices=1,
                value="pop art",
            )

            zero_shot_test_button = gr.Button("Stylize Image")

            zero_shot_output_img = gr.Image(label="Output Image", height=400)


    multi_model_button.click(fn=generate_multi_model, inputs=multi_input_img, outputs=multi_output_img)
    one_shot_test_button.click(fn=generate_one_shot, inputs=[one_shot_src_img, one_shot_ref_img], outputs=one_shot_output_img)
    zero_shot_test_button.click(fn=generate_zero_shot, inputs=[zero_shot_src_img, zero_shot_ref_prompt], outputs=zero_shot_output_img)

demo.queue(max_size=20)
demo.launch()