import gradio as gr
from PIL import Image
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.modules.stylegan2.model import StyledConv, ToRGB, EqualLinear, ResBlock, ConvLayer, PixelNorm
from utils.util import *
from utils.data_utils import Transforms
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from configs import parse_config
from utils.augmentation import ImagePathToImage
import clip
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from models.style_based_pix2pixII_model import CLIPFeats2Wplus
class Stylizer(nn.Module):
    """Encoder + style-modulated decoder that restyles an input image.

    A convolutional encoder turns a 3-channel image into a feature map; a
    StyleGAN2-style decoder (``StyledConv`` / ``ToRGB`` layers) then upsamples
    it while being modulated by ``n_latent`` style vectors. Style vectors come
    from the ``mapping_z`` MLP (fixed test latent, random, mixed, or a
    user-supplied ``z``) or, when built with ``phase=3``, from a CLIP-feature
    mapper (``clip_mapper``).

    Args:
        ngf: base channel width; encoder/decoder widths are multiples of it.
        phase: ``2`` loads ``model_weights`` strictly into the base model.
            ``3`` additionally builds ``clip_mapper``, loads ``model_weights``
            non-strictly and freezes every parameter outside ``clip_mapper``.
        model_weights: state dict to load; when ``None`` nothing is loaded.
    """

    def __init__(self, ngf=64, phase=2, model_weights=None):
        super().__init__()

        # encoder; trailing comments are the feature resolutions annotated by
        # the original authors (512 input -> 32 feature map)
        self.encoder = nn.Sequential(
            ConvLayer(3, ngf, 3),  # 512
            ResBlock(ngf * 1, ngf * 1),  # 256
            ResBlock(ngf * 1, ngf * 2),  # 128
            ResBlock(ngf * 2, ngf * 4),  # 64
            ResBlock(ngf * 4, ngf * 8),  # 32
            ConvLayer(ngf * 8, ngf * 8, 3),  # 32
        )

        # mapping network: PixelNorm followed by 8 EqualLinear(512 -> 512) layers.
        # Module order matches the original so state-dict keys are unchanged.
        self.mapping_z = nn.Sequential(
            PixelNorm(),
            *[EqualLinear(512, 512, activation='fused_lrelu', lr_mul=0.01) for _ in range(8)],
        )

        # style-based decoder: channel width per feature resolution
        channels = {
            32: ngf * 8,
            64: ngf * 8,
            128: ngf * 4,
            256: ngf * 2,
            512: ngf * 1,
        }
        self.decoder0 = StyledConv(channels[32], channels[32], 3, 512)
        self.to_rgb0 = ToRGB(channels[32], 512, upsample=False)
        for i in range(4):
            ichan = channels[2 ** (i + 5)]
            ochan = channels[2 ** (i + 6)]
            setattr(self, f'decoder{i + 1}a', StyledConv(ichan, ochan, 3, 512, upsample=True))
            setattr(self, f'decoder{i + 1}b', StyledConv(ochan, ochan, 3, 512))
            setattr(self, f'to_rgb{i + 1}', ToRGB(ochan, 512))
        # forward() reads styles[0] .. styles[2 * 3 + 3] == styles[9]
        self.n_latent = 10

        # Fixed latent used when forward() receives no style kwargs. Kept as a
        # plain attribute (not a registered buffer) so it is absent from
        # state_dict and strict loading of existing checkpoints still works;
        # get_styles() moves it to the input's device on use.
        self.test_z = torch.randn(1, 512)

        if phase == 2 and model_weights is not None:
            # load pretrained encoder and stylegan2 decoder
            self.load_state_dict(model_weights)
        if phase == 3:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load pretrained base model and freeze all params except clip mapper
            if model_weights is not None:
                self.load_state_dict(model_weights, strict=False)
            for name, param in self.named_parameters():
                if 'clip_mapper' in name:
                    print(f'{name} not freezed !')
                    continue
                param.requires_grad = False

    def get_styles(self, x, **kwargs):
        """Build the stack of ``n_latent`` style vectors for batch ``x``.

        Keyword selection, checked in this order:
          * no kwargs        -> fixed ``test_z`` for every batch item;
          * ``mixing=True``  -> two random latents, switched at a random index;
          * ``z=<tensor>``   -> the given latent codes;
          * ``clip_feats=``  -> output of ``clip_mapper`` (phase 3 only);
          * anything else (e.g. ``mixing=False``) -> a fresh random latent.

        Returns:
            For the ``mapping_z`` paths, a tensor of shape
            ``(n_latent, batch, 512)`` (the mapped ``(batch, 512)`` codes
            repeated along a new leading axis). The ``clip_feats`` path
            returns whatever ``clip_mapper`` produces; presumably the same
            layout, since forward() indexes it identically.
        """
        if len(kwargs) == 0:
            z = self.test_z.to(x.device).repeat(x.shape[0], 1)
            return self.mapping_z(z).repeat(self.n_latent, 1, 1)
        if kwargs.get('mixing'):
            w0 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            w1 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            # random.randint is inclusive: 1 <= inject_index <= n_latent - 1,
            # so both latents contribute at least one style vector
            inject_index = random.randint(1, self.n_latent - 1)
            return torch.cat([
                w0.repeat(inject_index, 1, 1),
                w1.repeat(self.n_latent - inject_index, 1, 1),
            ], dim=0)
        if 'z' in kwargs:
            return self.mapping_z(kwargs['z']).repeat(self.n_latent, 1, 1)
        if 'clip_feats' in kwargs:
            return self.clip_mapper(kwargs['clip_feats'])
        z = torch.randn(x.shape[0], 512, device=x.device)
        return self.mapping_z(z).repeat(self.n_latent, 1, 1)

    def forward(self, x, **kwargs):
        """Stylize ``x``; ``kwargs`` select the style source (see get_styles).

        Args:
            x: image batch; the first encoder layer expects 3 input channels.
                The app feeds tensors normalized with Normalize([0.5], [0.5]),
                i.e. presumably in [-1, 1].

        Returns:
            The generated image batch, clamped to [-1, 1] by ``F.hardtanh``.
        """
        feat = self.encoder(x)
        styles = self.get_styles(x, **kwargs)

        feat = self.decoder0(feat, styles[0])
        out = self.to_rgb0(feat, styles[1])
        for i in range(4):
            # each to_rgb shares its style index with the next stage's first conv
            feat = getattr(self, f'decoder{i + 1}a')(feat, styles[i * 2 + 1])
            feat = getattr(self, f'decoder{i + 1}b')(feat, styles[i * 2 + 2])
            out = getattr(self, f'to_rgb{i + 1}')(feat, styles[i * 2 + 3], out)
        return F.hardtanh(out)
def tensor2file(input_image):
    """Convert a generator output (or an image array) to a PIL image.

    Args:
        input_image: either a ``torch.Tensor`` batch indexed as
            ``[item, channel, row, col]`` with values in [-1, 1] (only the
            first item is converted; a single channel is tiled to RGB), or a
            numpy array already laid out as ``(H, W)`` / ``(H, W, C)`` in
            [0, 255]. Any other object is returned unchanged.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 pixel data, or
        ``input_image`` itself for unsupported types.
    """
    if isinstance(input_image, np.ndarray):
        # already an image array: use as-is
        image_numpy = input_image
    elif isinstance(input_image, torch.Tensor):
        image_numpy = input_image.detach()[0].cpu().float().numpy()
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # post-processing: CHW -> HWC and map [-1, 1] to [0, 255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        return input_image
    # Clip before casting: out-of-range floats would otherwise wrap around
    # when converted to uint8.
    image_numpy = np.clip(image_numpy, 0, 255).astype(np.uint8)
    return Image.fromarray(image_numpy)
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# demo can still start on CPU-only hosts instead of failing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
def generate_multi_model(input_img):
    """Stylize ``input_img`` with the phase-2 model and a randomly sampled style.

    The image is routed through the project's test-time data pipeline
    (``SuperDataset`` / ``CustomDataLoader``) configured from
    ``./exp/sp2pII-phase2.yaml``.

    Args:
        input_img: path of the source face image (the UI passes a filepath).

    Returns:
        The stylized image produced by ``tensor2file`` for the last batch
        yielded by the dataloader.

    Raises:
        RuntimeError: if the dataloader yields no batch for ``input_img``.
    """
    checkpoint_path = "./pretrained_models/phase2_pretrain_90000.pth"

    # parse config
    config = parse_config("./exp/sp2pII-phase2.yaml")
    # hard-code some parameters for test
    config['common']['phase'] = "test"
    config['dataset']['n_threads'] = 0  # test code only supports num_threads = 0
    config['dataset']['batch_size'] = 1  # test code only supports batch_size = 1
    config['dataset']['serial_batches'] = True  # disable data shuffling
    config['dataset']['no_flip'] = True  # disable flipping
    # override data augmentation with the testing settings
    config['dataset']['load_size'] = config['testing']['load_size']
    config['dataset']['crop_size'] = config['testing']['crop_size']
    config['dataset']['preprocess'] = config['testing']['preprocess']
    config['training']['pretrained_model'] = checkpoint_path
    # point the test split at the single uploaded image
    config['testing']['test_img'] = input_img
    config['testing']['test_video'] = None
    config['testing']['test_folder'] = None

    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)

    model_dict = torch.load(checkpoint_path, map_location='cpu')
    # init netG from the EMA generator weights
    model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
    model.eval()

    output_img = None
    # inference only: no autograd graph needed
    with torch.no_grad():
        for data in dataloader:
            real_A = data['test_A'].to(device)
            # mixing=False matches none of the explicit style kwargs, so the
            # model samples a fresh random latent for every call
            fake_B = model(real_A, mixing=False)
            output_img = tensor2file(fake_B)
    if output_img is None:
        raise RuntimeError(f"No test sample could be loaded for {input_img!r}")
    return output_img
def generate_one_shot(src_img, img_prompt):
    """Stylize ``src_img`` in the style of the selected reference image.

    Args:
        src_img: path of the source face image (the UI passes a filepath).
        img_prompt: radio label such as ``'ref01'``; its last two characters
            select both the checkpoint directory ``./checkpoints/<id>`` and
            the reference image ``./example/reference/<id>.png``.

    Returns:
        The stylized image produced by ``tensor2file``.
    """
    style_id = img_prompt[-2:]

    # init model (phase 3: base model + CLIP mapper) on the inference device
    state_dict = torch.load(f"./checkpoints/{style_id}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model']).to(device)
    model.eval()
    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)

    # image transform for stylizer: resize, to tensor, map [0, 1] -> [-1, 1]
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5]),
    ])

    # get L2-normalized CLIP features of the reference style image
    with torch.no_grad(), Image.open(f"./example/reference/{style_id}.png") as ref_img:
        ref = img_preprocess(ref_img).unsqueeze(0).to(device)
        clip_feats = clip_model.encode_image(ref)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)
        # clip.load keeps fp16 weights on GPU devices; cast so the stylizer's
        # (fp32) modules receive fp32 input. No-op on CPU.
        clip_feats = clip_feats.float()

    # load image as RGB & to tensor (context manager closes the file handle)
    with Image.open(src_img) as src:
        img = src.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it !
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    return tensor2file(res)
def generate_zero_shot(src_img, txt_prompt):
    """Stylize ``src_img`` according to a text style prompt.

    Args:
        src_img: path of the source face image (the UI passes a filepath).
        txt_prompt: style prompt such as ``'pop art'``; spaces are replaced
            by underscores to locate ``./checkpoints/<prompt>/epoch_latest.pth``.

    Returns:
        The stylized image produced by ``tensor2file``.
    """
    # init model (phase 3: base model + CLIP mapper) on the inference device
    state_dict = torch.load(f"./checkpoints/{txt_prompt.replace(' ', '_')}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model']).to(device)
    model.eval()
    clip_model, _ = clip.load('ViT-B/32', device=device)

    # image transform for stylizer: resize, to tensor, map [0, 1] -> [-1, 1]
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5]),
    ])

    # get L2-normalized CLIP text features of the prompt
    with torch.no_grad():
        text = clip.tokenize(txt_prompt).to(device)
        clip_feats = clip_model.encode_text(text)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)
        # clip.load keeps fp16 weights on GPU devices; cast so the stylizer's
        # (fp32) modules receive fp32 input. No-op on CPU.
        clip_feats = clip_feats.float()

    # load image as RGB & to tensor (context manager closes the file handle)
    with Image.open(src_img) as src:
        img = src.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)

    # stylize it !
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    return tensor2file(res)
# Source face images offered as clickable examples in every tab.
_EXAMPLE_SOURCES = [f"./example/source/{i:02d}.png" for i in range(1, 5)]

with gr.Blocks() as demo:
    # page title
    gr.Markdown("# MMFS")
    # one tab per stylization mode
    with gr.Tabs():
        with gr.TabItem("Multi-Model"):
            multi_input_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=_EXAMPLE_SOURCES, inputs=multi_input_img)
            multi_model_button = gr.Button("Random Stylize")
            multi_output_img = gr.Image(label="Output Image", height=400)
        with gr.TabItem("One-Shot"):
            one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=_EXAMPLE_SOURCES, inputs=one_shot_src_img)
            # display-only thumbnails of the available reference styles
            with gr.Row():
                for i in range(1, 5):
                    gr.Image(
                        value=f"example/reference/{i:02d}.png",
                        type='pil',
                        label=f"ref{i:02d}",
                        height=100,
                        width=100,
                        interactive=False,
                    )
            one_shot_ref_img = gr.Radio(['ref01', 'ref02', 'ref03', 'ref04'], value="ref01", label="Select a reference style image")
            one_shot_test_button = gr.Button("Stylize Image")
            one_shot_output_img = gr.Image(label="Output Image", height=400)
        with gr.TabItem("Zero-Shot"):
            zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=_EXAMPLE_SOURCES, inputs=zero_shot_src_img)
            zero_shot_ref_prompt = gr.Dropdown(
                choices=["pop art", "watercolor painting"],
                label="Txt Prompt",
                info="Select a reference style prompt",
                value="pop art",
            )
            zero_shot_test_button = gr.Button("Stylize Image")
            zero_shot_output_img = gr.Image(label="Output Image", height=400)

    # wire each button to its generator
    multi_model_button.click(fn=generate_multi_model, inputs=multi_input_img, outputs=multi_output_img)
    one_shot_test_button.click(fn=generate_one_shot, inputs=[one_shot_src_img, one_shot_ref_img], outputs=one_shot_output_img)
    zero_shot_test_button.click(fn=generate_zero_shot, inputs=[zero_shot_src_img, zero_shot_ref_prompt], outputs=zero_shot_output_img)

# NOTE(review): no demo.launch() is visible in this file; confirm the app is
# launched elsewhere, otherwise the UI is defined but never served.