import os
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import clip
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from models.modules.stylegan2.model import StyledConv, ToRGB, EqualLinear, ResBlock, ConvLayer, PixelNorm
from models.style_based_pix2pixII_model import CLIPFeats2Wplus
from utils.util import *
from utils.data_utils import Transforms
from utils.augmentation import ImagePathToImage
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from configs import parse_config


class Stylizer(nn.Module):
    def __init__(self, ngf=64, phase=2, model_weights=None):
        super(Stylizer, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            ConvLayer(3, ngf, 3),            # 512
            ResBlock(ngf * 1, ngf * 1),      # 256
            ResBlock(ngf * 1, ngf * 2),      # 128
            ResBlock(ngf * 2, ngf * 4),      # 64
            ResBlock(ngf * 4, ngf * 8),      # 32
            ConvLayer(ngf * 8, ngf * 8, 3)   # 32
        )

        # mapping network
        self.mapping_z = nn.Sequential(*([PixelNorm()] + [
            EqualLinear(512, 512, activation='fused_lrelu', lr_mul=0.01) for _ in range(8)
        ]))

        # style-based decoder
        channels = {
            32: ngf * 8,
            64: ngf * 8,
            128: ngf * 4,
            256: ngf * 2,
            512: ngf * 1
        }
        self.decoder0 = StyledConv(channels[32], channels[32], 3, 512)
        self.to_rgb0 = ToRGB(channels[32], 512, upsample=False)
        for i in range(4):
            ichan = channels[2 ** (i + 5)]
            ochan = channels[2 ** (i + 6)]
            setattr(self, f'decoder{i + 1}a', StyledConv(ichan, ochan, 3, 512, upsample=True))
            setattr(self, f'decoder{i + 1}b', StyledConv(ochan, ochan, 3, 512))
            setattr(self, f'to_rgb{i + 1}', ToRGB(ochan, 512))
        self.n_latent = 10

        # random style for testing
        self.test_z = torch.randn(1, 512)

        # load pretrained model weights
        if phase == 2:
            # load pretrained encoder and stylegan2 decoder
            self.load_state_dict(model_weights)
        if phase == 3:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load pretrained base model and freeze all params except the clip mapper
            self.load_state_dict(model_weights, strict=False)
            params = dict(self.named_parameters())
            for k in params.keys():
                if 'clip_mapper' in k:
                    print(f'{k} not frozen!')
                    continue
                params[k].requires_grad = False

    def get_styles(self, x, **kwargs):
        if len(kwargs) == 0:
            # fixed random style used for testing
            return self.mapping_z(self.test_z.to(x.device).repeat(x.shape[0], 1)).repeat(self.n_latent, 1, 1)
        elif 'mixing' in kwargs and kwargs['mixing']:
            # style mixing: inject a second latent at a random layer index
            w0 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            w1 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            inject_index = random.randint(1, self.n_latent - 1)
            return torch.cat([
                w0.repeat(inject_index, 1, 1),
                w1.repeat(self.n_latent - inject_index, 1, 1)
            ])
        elif 'z' in kwargs:
            return self.mapping_z(kwargs['z']).repeat(self.n_latent, 1, 1)
        elif 'clip_feats' in kwargs:
            return self.clip_mapper(kwargs['clip_feats'])
        else:
            z = torch.randn(x.shape[0], 512, device=x.device)
            return self.mapping_z(z).repeat(self.n_latent, 1, 1)

    def forward(self, x, **kwargs):
        # encode
        feat = self.encoder(x)
        # get style code
        styles = self.get_styles(x, **kwargs)
        # style-based generation
        feat = self.decoder0(feat, styles[0])
        out = self.to_rgb0(feat, styles[1])
        for i in range(4):
            feat = getattr(self, f'decoder{i + 1}a')(feat, styles[i * 2 + 1])
            feat = getattr(self, f'decoder{i + 1}b')(feat, styles[i * 2 + 2])
            out = getattr(self, f'to_rgb{i + 1}')(feat, styles[i * 2 + 3], out)
        return F.hardtanh(out)


def tensor2file(input_image):
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):
            # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:
        # if it is already a numpy array, do nothing
        image_numpy = input_image
    if image_numpy.shape[2] > 3:
        # drop extra channels (e.g. alpha) before conversion
        image_numpy = image_numpy[:, :, :3]
    image_numpy = image_numpy.astype(np.uint8)
    return Image.fromarray(image_numpy)


device = "cuda"


def generate_multi_model(input_img):
    # parse config
    config = parse_config("./exp/sp2pII-phase2.yaml")

    # hard-code some parameters for test
    config['common']['phase'] = "test"
    config['dataset']['n_threads'] = 0       # test code only supports num_threads = 0
    config['dataset']['batch_size'] = 1      # test code only supports batch_size = 1
    config['dataset']['serial_batches'] = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config['dataset']['no_flip'] = True         # no flip; comment this line if results on flipped images are needed.

    # override data augmentation
    config['dataset']['load_size'] = config['testing']['load_size']
    config['dataset']['crop_size'] = config['testing']['crop_size']
    config['dataset']['preprocess'] = config['testing']['preprocess']
    config['training']['pretrained_model'] = "./pretrained_models/phase2_pretrain_90000.pth"

    # add testing path
    config['testing']['test_img'] = input_img
    config['testing']['test_video'] = None
    config['testing']['test_folder'] = None

    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)
    model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')

    # init netG
    model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
    model.eval()

    with torch.no_grad():
        for data in dataloader:
            real_A = data['test_A'].to(device)
            fake_B = model(real_A, mixing=False)
            output_img = tensor2file(fake_B)  # get image results
    return output_img


def generate_one_shot(src_img, img_prompt):
    # init model
    state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)

    # image transform for stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])

    # get clip features of the reference style image
    with torch.no_grad():
        img = img_preprocess(Image.open(f"./example/reference/{img_prompt[-2:]}.png")).unsqueeze(0).to(device)
        clip_feats = clip_model.encode_image(img)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)

    # load image & to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it!
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    output_img = tensor2file(res)  # get image results
    return output_img


def generate_zero_shot(src_img, txt_prompt):
    # init model
    state_dict = torch.load(f"./checkpoints/{txt_prompt.replace(' ', '_')}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)

    # image transform for stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])

    # get clip features of the text prompt
    with torch.no_grad():
        text = clip.tokenize(txt_prompt).to(device)
        clip_feats = clip_model.encode_text(text)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)

    # load image & to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it!
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    output_img = tensor2file(res)  # get image results
    return output_img


with gr.Blocks() as demo:
    # header text
    gr.Markdown("# MMFS")

    # multiple tabs
    with gr.Tabs():
        with gr.TabItem("Multi-Model"):
            multi_input_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png",
                                  "./example/source/03.png", "./example/source/04.png"],
                        inputs=multi_input_img)
            multi_model_button = gr.Button("Random Stylize")
            multi_output_img = gr.Image(label="Output Image", height=400)

        with gr.TabItem("One-Shot"):
            one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png",
                                  "./example/source/03.png", "./example/source/04.png"],
                        inputs=one_shot_src_img)
            with gr.Row():
                gr.Image(shape=(100, 100), value=Image.open("example/reference/01.png"), type='pil', label="ref01")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/02.png"), type='pil', label="ref02")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/03.png"), type='pil', label="ref03")
                gr.Image(shape=(100, 100), value=Image.open("example/reference/04.png"), type='pil', label="ref04")
            one_shot_ref_img = gr.Radio(['ref01', 'ref02', 'ref03', 'ref04'], value="ref01",
                                        label="Select a reference style image")
            one_shot_test_button = gr.Button("Stylize Image")
            one_shot_output_img = gr.Image(label="Output Image", height=400)

        with gr.TabItem("Zero-Shot"):
            zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png",
                                  "./example/source/03.png", "./example/source/04.png"],
                        inputs=zero_shot_src_img)
            zero_shot_ref_prompt = gr.Dropdown(
                label="Txt Prompt",
                info="Select a reference style prompt",
                choices=[
                    "pop art",
                    "watercolor painting",
                ],
                max_choices=1,
                value="pop art",
            )
            zero_shot_test_button = gr.Button("Stylize Image")
            zero_shot_output_img = gr.Image(label="Output Image", height=400)

    multi_model_button.click(fn=generate_multi_model, inputs=multi_input_img, outputs=multi_output_img)
    one_shot_test_button.click(fn=generate_one_shot, inputs=[one_shot_src_img, one_shot_ref_img], outputs=one_shot_output_img)
    zero_shot_test_button.click(fn=generate_zero_shot, inputs=[zero_shot_src_img, zero_shot_ref_prompt], outputs=zero_shot_output_img)

demo.queue(max_size=20)
demo.launch()
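
# --- Illustrative sketch (kept commented out, not part of the demo) ----------
# `get_styles` also accepts an explicit latent code via the `z=` branch, which
# none of the Gradio callbacks above exercise. A rough way to stylize a single
# image with a fixed, reproducible style could look like the snippet below.
# Paths, the 512-dim latent size, and the normalization follow the code above;
# treat this as a sketch under those assumptions, not a tested entry point.
#
# state = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')
# stylizer = Stylizer(ngf=64, phase=2, model_weights=state['G_ema_model']).to(device).eval()
# transform = Compose([Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
#                      ToTensor(), Normalize([0.5], [0.5])])
# src = transform(Image.open("./example/source/01.png").convert('RGB')).unsqueeze(0).to(device)
# z = torch.randn(1, 512, device=device)  # seeding the RNG here would make the style repeatable
# with torch.no_grad():
#     out = stylizer(src, z=z)
# tensor2file(out).save("stylized_01.png")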