"""Test script for the style-based pix2pix II (sp2pII) stylizer.

Every image in ``--in_folder`` is stylized towards a CLIP prompt and written
as a PNG to ``--out_folder``. The prompt is the reference image
``--img_prompt`` if that file exists, otherwise the text ``--txt_prompt``.
"""
import argparse
import os

import numpy as np
from PIL import Image
from tqdm import tqdm
import clip
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from models.style_based_pix2pixII_model import Stylizer, TrainingPhase

# Lower-case file extensions treated as input images.
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace whose ``phase`` has been converted to ``TrainingPhase``.

    Exits via ``parser.error`` if the requested phase has no inference path here.
    """
    parser = argparse.ArgumentParser(description='sp2pII test')
    parser.add_argument('--ckpt', type=str, default='./checkpoints/01/epoch_latest.pth')
    parser.add_argument('--in_folder', type=str, default='./example/source')
    parser.add_argument('--out_folder', type=str, default='./example/outputs/one-shot')
    parser.add_argument('--phase', type=int, default=3)
    parser.add_argument('--txt_prompt', type=str, default='sketch with black pencil')
    parser.add_argument('--img_prompt', type=str, default='./example/reference/01.png')
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()
    args.phase = TrainingPhase(args.phase)
    # Only the CLIP-mapping phase is handled below; any other phase used to
    # leave the model output undefined. Fail fast with a clear message instead.
    if args.phase != TrainingPhase.CLIP_MAPPING:
        parser.error(f'phase {args.phase} is not supported by this script; '
                     f'use {TrainingPhase.CLIP_MAPPING}')
    return args


def load_models(args):
    """Build the stylizer from ``args.ckpt`` and load CLIP ViT-B/32 on ``args.device``.

    Returns:
        (model, clip_model, img_preprocess): both networks in eval mode with
        gradients disabled; ``img_preprocess`` is CLIP's own image transform.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(args.ckpt, map_location='cpu')
    model = Stylizer(ngf=64, phase=args.phase, model_weights=state_dict['G_ema_model'])
    model.to(args.device)
    model.eval()
    model.requires_grad_(False)

    clip_model, img_preprocess = clip.load('ViT-B/32', device=args.device)
    clip_model.eval()
    clip_model.requires_grad_(False)
    return model, clip_model, img_preprocess


@torch.no_grad()
def encode_prompt(args, clip_model, img_preprocess):
    """Return the row-wise L2-normalised CLIP embedding of the style prompt.

    The image prompt takes precedence: if ``args.img_prompt`` is an existing
    file it is encoded as an image, otherwise ``args.txt_prompt`` is encoded
    as text.
    """
    if os.path.isfile(args.img_prompt):
        with Image.open(args.img_prompt) as ref:
            img = img_preprocess(ref).unsqueeze(0).to(args.device)
        feats = clip_model.encode_image(img)
    else:
        tokens = clip.tokenize(args.txt_prompt).to(args.device)
        feats = clip_model.encode_text(tokens)
    # clip.load() keeps half-precision weights on CUDA, so the features may be
    # fp16; cast to float32 (presumably the stylizer's dtype -- it is built with
    # default settings above) before normalising.
    feats = feats.float()
    return feats / feats.norm(dim=1, keepdim=True)


def tensor_to_image(res):
    """Convert a (1, C, H, W) stylizer output to a uint8 PIL image.

    The output is de-normalised with ``x * 0.5 + 0.5`` (the inverse of the
    input ``Normalize([0.5], [0.5])``), then clipped to [0, 1]. Without the
    clip, casting out-of-range floats to uint8 overflows and produces
    speckle artefacts.
    """
    arr = res.cpu().numpy()[0]
    arr = np.transpose(arr, (1, 2, 0)) * 0.5 + 0.5
    arr = np.clip(arr, 0.0, 1.0)
    return Image.fromarray(np.round(arr * 255).astype(np.uint8))


def main():
    """Stylize every image in ``--in_folder`` and save PNGs to ``--out_folder``."""
    args = parse_args()
    os.makedirs(args.out_folder, exist_ok=True)

    model, clip_model, img_preprocess = load_models(args)
    clip_feats = encode_prompt(args, clip_model, img_preprocess)

    # Input transform for the stylizer: resize to 512x512, scale to [-1, 1].
    img_transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
        ToTensor(),
        Normalize([0.5], [0.5]),
    ])

    # Filter before tqdm so the progress total counts only images; sort so
    # the processing order is deterministic (os.listdir order is arbitrary).
    files = sorted(
        fn for fn in os.listdir(args.in_folder)
        if os.path.splitext(fn)[1].lower() in IMG_EXTENSIONS
    )
    for fn in tqdm(files):
        prefix = os.path.splitext(fn)[0]
        with Image.open(os.path.join(args.in_folder, fn)) as src:
            rgb = src if src.mode == 'RGB' else src.convert('RGB')
            img = img_transform(rgb).unsqueeze(0).to(args.device)

        # stylize it !
        with torch.no_grad():
            res = model(img, clip_feats=clip_feats)

        tensor_to_image(res).save(os.path.join(args.out_folder, prefix + '.png'))


if __name__ == '__main__':
    main()