import gradio as gr
from PIL import Image
import os
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.modules.stylegan2.model import StyledConv, ToRGB, EqualLinear, ResBlock, ConvLayer, PixelNorm
from utils.util import *
from utils.data_utils import Transforms
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from configs import parse_config
from utils.augmentation import ImagePathToImage
import clip
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from models.style_based_pix2pixII_model import CLIPFeats2Wplus
class Stylizer(nn.Module):

    def __init__(self, ngf=64, phase=2, model_weights=None):
        super(Stylizer, self).__init__()
        # encoder
        self.encoder = nn.Sequential(
            ConvLayer(3, ngf, 3),           # 512
            ResBlock(ngf * 1, ngf * 1),     # 256
            ResBlock(ngf * 1, ngf * 2),     # 128
            ResBlock(ngf * 2, ngf * 4),     # 64
            ResBlock(ngf * 4, ngf * 8),     # 32
            ConvLayer(ngf * 8, ngf * 8, 3)  # 32
        )
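        # the encoder halves the spatial resolution at each ResBlock: a 512x512
        # input becomes a 32x32 feature map while channels grow from ngf to
        # ngf * 8 (resolutions annotated above)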
        # mapping network
        self.mapping_z = nn.Sequential(*([ PixelNorm() ] + [ EqualLinear(512, 512, activation='fused_lrelu', lr_mul=0.01) for _ in range(8) ]))
        # style-based decoder
        channels = {
            32 : ngf * 8,
            64 : ngf * 8,
            128: ngf * 4,
            256: ngf * 2,
            512: ngf * 1
        }
        self.decoder0 = StyledConv(channels[32], channels[32], 3, 512)
        self.to_rgb0 = ToRGB(channels[32], 512, upsample=False)
        for i in range(4):
            ichan = channels[2 ** (i + 5)]
            ochan = channels[2 ** (i + 6)]
            setattr(self, f'decoder{i + 1}a', StyledConv(ichan, ochan, 3, 512, upsample=True))
            setattr(self, f'decoder{i + 1}b', StyledConv(ochan, ochan, 3, 512))
            setattr(self, f'to_rgb{i + 1}', ToRGB(ochan, 512))
        self.n_latent = 10
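        # n_latent counts the style vectors indexed in forward(): indices 0-9
        # across decoder0/to_rgb0 and the four upsampling stages, with adjacent
        # layers sharing a style as in StyleGAN2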
        # random style for testing
        self.test_z = torch.randn(1, 512)
        # load pretrained model weights
        if phase == 2:
            # load pretrained encoder and stylegan2 decoder
            self.load_state_dict(model_weights)
        if phase == 3:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load the pretrained base model and freeze all params except the clip mapper
            self.load_state_dict(model_weights, strict=False)
            params = dict(self.named_parameters())
            for k in params.keys():
                if 'clip_mapper' in k:
                    print(f'{k} not frozen!')
                    continue
                params[k].requires_grad = False
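    # get_styles selects the style codes for one forward pass:
    #   no kwargs      -> deterministic style from the fixed test_z
    #   mixing=True    -> style mixing of two random latents at a random index
    #   z=...          -> caller-supplied latent through the mapping network
    #   clip_feats=... -> CLIP-conditioned styles via clip_mapper (phase 3)
    #   anything else  -> a fresh random latent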
    def get_styles(self, x, **kwargs):
        if len(kwargs) == 0:
            return self.mapping_z(self.test_z.to(x.device).repeat(x.shape[0], 1)).repeat(self.n_latent, 1, 1)
        elif 'mixing' in kwargs and kwargs['mixing']:
            w0 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            w1 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            inject_index = random.randint(1, self.n_latent - 1)
            return torch.cat([
                w0.repeat(inject_index, 1, 1),
                w1.repeat(self.n_latent - inject_index, 1, 1)
            ])
        elif 'z' in kwargs:
            return self.mapping_z(kwargs['z']).repeat(self.n_latent, 1, 1)
        elif 'clip_feats' in kwargs:
            return self.clip_mapper(kwargs['clip_feats'])
        else:
            z = torch.randn(x.shape[0], 512, device=x.device)
            return self.mapping_z(z).repeat(self.n_latent, 1, 1)
    def forward(self, x, **kwargs):
        # encode
        feat = self.encoder(x)
        # get style code
        styles = self.get_styles(x, **kwargs)
        # style-based generate
        feat = self.decoder0(feat, styles[0])
        out = self.to_rgb0(feat, styles[1])
        for i in range(4):
            feat = getattr(self, f'decoder{i + 1}a')(feat, styles[i * 2 + 1])
            feat = getattr(self, f'decoder{i + 1}b')(feat, styles[i * 2 + 2])
            out = getattr(self, f'to_rgb{i + 1}')(feat, styles[i * 2 + 3], out)
        return F.hardtanh(out)
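# Minimal usage sketch (hypothetical path; assumes a phase-2 checkpoint with a
# 'G_ema_model' entry, as loaded below in generate_multi_model):
#   weights = torch.load('path/to/checkpoint.pth', map_location='cpu')['G_ema_model']
#   net = Stylizer(ngf=64, phase=2, model_weights=weights).eval()
#   x = torch.randn(1, 3, 512, 512)  # stand-in for an image tensor in [-1, 1]
#   y = net(x)                       # stylized output, clamped to [-1, 1] by hardtanh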
def tensor2file(input_image):
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is already a numpy array, do nothing
        image_numpy = input_image
    image_numpy = image_numpy.astype(np.uint8)
    return Image.fromarray(image_numpy)
device = "cuda" if torch.cuda.is_available() else "cpu"
def generate_multi_model(input_img):
    # parse config
    config = parse_config("./exp/sp2pII-phase2.yaml")
    # hard-code some parameters for test
    config['common']['phase'] = "test"
    config['dataset']['n_threads'] = 0          # test code only supports num_threads = 0
    config['dataset']['batch_size'] = 1         # test code only supports batch_size = 1
    config['dataset']['serial_batches'] = True  # disable data shuffling; comment this line out to test on randomly chosen images
    config['dataset']['no_flip'] = True         # no flip; comment this line out to test on flipped images
    # override data augmentation
    config['dataset']['load_size'] = config['testing']['load_size']
    config['dataset']['crop_size'] = config['testing']['crop_size']
    config['dataset']['preprocess'] = config['testing']['preprocess']
    config['training']['pretrained_model'] = "./pretrained_models/phase2_pretrain_90000.pth"
    # add testing path
    config['testing']['test_img'] = input_img
    config['testing']['test_video'] = None
    config['testing']['test_folder'] = None
    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)
    model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')
    # init netG
    model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
    model.eval()
    for data in dataloader:
        real_A = data['test_A'].to(device)
        with torch.no_grad():
            fake_B = model(real_A, mixing=False)
        output_img = tensor2file(fake_B)  # get image results
        return output_img
def generate_one_shot(src_img, img_prompt):
    # init model
    state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)
    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)
    # image transform for the stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])
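    # Normalize([0.5], [0.5]) maps [0, 1] tensors to [-1, 1], matching the
    # hardtanh output range of the stylizer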
    # get clip features from the reference style image
    with torch.no_grad():
        img = img_preprocess(Image.open(f"./example/reference/{img_prompt[-2:]}.png")).unsqueeze(0).to(device)
        clip_feats = clip_model.encode_image(img)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)
    # load image & convert to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)
    # stylize it!
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    output_img = tensor2file(res)  # get image results
    return output_img
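# generate_zero_shot mirrors generate_one_shot, but conditions the stylizer on
# the CLIP text embedding of the prompt rather than the CLIP image embedding of
# a reference image; both are L2-normalized into the same joint embedding space.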
def generate_zero_shot(src_img, txt_prompt):
    # init model
    state_dict = torch.load(f"./checkpoints/{txt_prompt.replace(' ', '_')}/epoch_latest.pth", map_location='cpu')
    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
    model.to(device)
    model.eval()
    model.requires_grad_(False)
    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
    clip_model.eval()
    clip_model.requires_grad_(False)
    # image transform for the stylizer
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])
    # get clip features from the text prompt
    with torch.no_grad():
        text = clip.tokenize(txt_prompt).to(device)
        clip_feats = clip_model.encode_text(text)
        clip_feats /= clip_feats.norm(dim=1, keepdim=True)
    # load image & convert to tensor
    img = Image.open(src_img)
    if not img.mode == 'RGB':
        img = img.convert('RGB')
    img = img_transform(img).unsqueeze(0).to(device)
    # stylize it!
    with torch.no_grad():
        res = model(img, clip_feats=clip_feats)
    output_img = tensor2file(res)  # get image results
    return output_img
with gr.Blocks() as demo:
    # page title
    gr.Markdown("# MMFS")
    # one tab per mode
    with gr.Tabs():
        with gr.TabItem("Multi-Model"):
            multi_input_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=multi_input_img)
            multi_model_button = gr.Button("Random Stylize")
            multi_output_img = gr.Image(label="Output Image", height=400)
        with gr.TabItem("One-Shot"):
            one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=one_shot_src_img)
            with gr.Row():
                gr.Image(height=100, width=100, value=Image.open("example/reference/01.png"), type='pil', label="ref01")
                gr.Image(height=100, width=100, value=Image.open("example/reference/02.png"), type='pil', label="ref02")
                gr.Image(height=100, width=100, value=Image.open("example/reference/03.png"), type='pil', label="ref03")
                gr.Image(height=100, width=100, value=Image.open("example/reference/04.png"), type='pil', label="ref04")
            one_shot_ref_img = gr.Radio(['ref01', 'ref02', 'ref03', 'ref04'], value="ref01", label="Select a reference style image")
            one_shot_test_button = gr.Button("Stylize Image")
            one_shot_output_img = gr.Image(label="Output Image", height=400)
        with gr.TabItem("Zero-Shot"):
            zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
            gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=zero_shot_src_img)
            zero_shot_ref_prompt = gr.Dropdown(
                label="Text Prompt",
                info="Select a reference style prompt",
                choices=[
                    "pop art",
                    "watercolor painting",
                ],
                value="pop art",
            )
            zero_shot_test_button = gr.Button("Stylize Image")
            zero_shot_output_img = gr.Image(label="Output Image", height=400)
    multi_model_button.click(fn=generate_multi_model, inputs=multi_input_img, outputs=multi_output_img)
    one_shot_test_button.click(fn=generate_one_shot, inputs=[one_shot_src_img, one_shot_ref_img], outputs=one_shot_output_img)
    zero_shot_test_button.click(fn=generate_zero_shot, inputs=[zero_shot_src_img, zero_shot_ref_prompt], outputs=zero_shot_output_img)

demo.queue(max_size=20)
demo.launch()