import os
import cv2
import time
import random
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms
from accelerate.utils import set_seed

from src import (FontDiffuserDPMPipeline,
                 FontDiffuserModelDPM,
                 build_ddpm_scheduler,
                 build_unet,
                 build_content_encoder,
                 build_style_encoder)
from utils import (ttf2im,
                   load_ttf,
                   is_char_in_font,
                   save_args_to_yaml,
                   save_single_image,
                   save_image_with_content_style)


def arg_parse():
    from configs.fontdiffuser import get_parser

    parser = get_parser()
    parser.add_argument("--ckpt_dir", type=str, default=None)
    parser.add_argument("--demo", action="store_true")
    parser.add_argument("--controlnet", action="store_true",
                        help="In demo mode, enable ControlNet post-processing.")
    parser.add_argument("--character_input", action="store_true")
    parser.add_argument("--content_character", type=str, default=None)
    parser.add_argument("--content_image_path", type=str, default=None)
    parser.add_argument("--style_image_path", type=str, default=None)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument("--save_image_dir", type=str, default=None,
                        help="Directory to save the generated images.")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")
    args = parser.parse_args()

    # Expand the scalar image sizes into (height, width) tuples.
    style_image_size = args.style_image_size
    content_image_size = args.content_image_size
    args.style_image_size = (style_image_size, style_image_size)
    args.content_image_size = (content_image_size, content_image_size)

    return args


def image_process(args, content_image=None, style_image=None):
    if not args.demo:
        # Read the content image (rendered from the TTF font or loaded from
        # disk) and the style image.
        if args.character_input:
            assert args.content_character is not None, "The content_character should not be None."
            if not is_char_in_font(font_path=args.ttf_path, char=args.content_character):
                return None, None, None
            font = load_ttf(ttf_path=args.ttf_path)
            content_image = ttf2im(font=font, char=args.content_character)
            content_image_pil = content_image.copy()
        else:
            content_image = Image.open(args.content_image_path).convert('RGB')
            content_image_pil = None
        style_image = Image.open(args.style_image_path).convert('RGB')
    else:
        # In demo mode the images are passed in directly.
        assert style_image is not None, "The style image should not be None."
        if args.character_input:
            assert args.content_character is not None, "The content_character should not be None."
            if not is_char_in_font(font_path=args.ttf_path, char=args.content_character):
                return None, None, None
            font = load_ttf(ttf_path=args.ttf_path)
            content_image = ttf2im(font=font, char=args.content_character)
        else:
            assert content_image is not None, "The content image should not be None."
        content_image_pil = None

    # Dataset transforms: resize, convert to a tensor, and normalize to [-1, 1].
    content_inference_transforms = transforms.Compose(
        [transforms.Resize(args.content_image_size,
                           interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])
    style_inference_transforms = transforms.Compose(
        [transforms.Resize(args.style_image_size,
                           interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    content_image = content_inference_transforms(content_image)[None, :]
    style_image = style_inference_transforms(style_image)[None, :]

    return content_image, style_image, content_image_pil
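# Illustrative usage sketch of image_process (an assumption, not invoked by
# this script). The SimpleNamespace fields mirror the CLI arguments above; the
# file paths and the 96x96 sizes are hypothetical placeholders.
#
#   from types import SimpleNamespace
#   demo_args = SimpleNamespace(
#       demo=False, character_input=False, content_character=None,
#       content_image_path="examples/content.png",  # hypothetical path
#       style_image_path="examples/style.png",      # hypothetical path
#       ttf_path="ttf/KaiXinSongA.ttf",
#       content_image_size=(96, 96), style_image_size=(96, 96))
#   content, style, _ = image_process(args=demo_args)
#   # Both tensors have shape (1, 3, 96, 96) with values in [-1, 1].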
def load_fontdiffuser_pipeline(args):
    # Load the model state_dicts.
    unet = build_unet(args=args)
    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth"))
    style_encoder = build_style_encoder(args=args)
    style_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/style_encoder.pth"))
    content_encoder = build_content_encoder(args=args)
    content_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/content_encoder.pth"))
    model = FontDiffuserModelDPM(
        unet=unet,
        style_encoder=style_encoder,
        content_encoder=content_encoder)
    model.to(args.device)
    print("Loaded the model state_dict successfully!")

    # Load the training DDPM scheduler.
    train_scheduler = build_ddpm_scheduler(args=args)
    print("Loaded the training DDPM scheduler successfully!")

    # Load the DPM-Solver pipeline used to generate samples.
    pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale)
    print("Loaded the DPM-Solver pipeline successfully!")

    return pipe


def sampling(args, pipe, content_image=None, style_image=None):
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        # Save the sampling config alongside the outputs.
        save_args_to_yaml(args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml")
    if args.seed:
        set_seed(seed=args.seed)

    content_image, style_image, content_image_pil = image_process(
        args=args, content_image=content_image, style_image=style_image)
    if content_image is None:
        print("The content_character you provided is not in the ttf. "
              "Please change the content_character, or change the ttf.")
        return None

    with torch.no_grad():
        content_image = content_image.to(args.device)
        style_image = style_image.to(args.device)
        print("Sampling by DPM-Solver++ ......")
        start = time.time()
        images = pipe.generate(
            content_images=content_image,
            style_images=style_image,
            batch_size=1,
            order=args.order,
            num_inference_step=args.num_inference_steps,
            content_encoder_downsample_size=args.content_encoder_downsample_size,
            t_start=args.t_start,
            t_end=args.t_end,
            dm_size=args.content_image_size,
            algorithm_type=args.algorithm_type,
            skip_type=args.skip_type,
            method=args.method,
            correcting_x0_fn=args.correcting_x0_fn)
        end = time.time()

        if args.save_image:
            print("Saving the image ......")
            save_single_image(save_dir=args.save_image_dir, image=images[0])
            if args.character_input:
                save_image_with_content_style(save_dir=args.save_image_dir,
                                              image=images[0],
                                              content_image_pil=content_image_pil,
                                              content_image_path=None,
                                              style_image_path=args.style_image_path,
                                              resolution=args.resolution)
            else:
                save_image_with_content_style(save_dir=args.save_image_dir,
                                              image=images[0],
                                              content_image_pil=None,
                                              content_image_path=args.content_image_path,
                                              style_image_path=args.style_image_path,
                                              resolution=args.resolution)
        print(f"Finished the sampling process, costing {end - start:.2f}s.")
        return images[0]
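# Typical invocations of this script (a sketch; the file name "sample.py", the
# checkpoint directory, and the image paths are assumptions):
#
#   python sample.py --ckpt_dir ckpt \
#       --content_image_path examples/content.png \
#       --style_image_path examples/style.png \
#       --save_image --save_image_dir outputs
#
# With --character_input, the content glyph is rendered from --ttf_path instead
# of being read from --content_image_path:
#
#   python sample.py --ckpt_dir ckpt --character_input --content_character "龍" \
#       --style_image_path examples/style.png --save_image --save_image_dir outputs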
def load_controlnet_pipeline(args,
                             config_path="lllyasviel/sd-controlnet-canny",
                             ckpt_path="runwayml/stable-diffusion-v1-5"):
    from diffusers import (ControlNetModel,
                           StableDiffusionControlNetPipeline,
                           UniPCMultistepScheduler)

    # Load the ControlNet model and the Stable Diffusion pipeline around it.
    controlnet = ControlNetModel.from_pretrained(config_path,
                                                 torch_dtype=torch.float16,
                                                 cache_dir=f"{args.ckpt_dir}/controlnet")
    print("Loaded the ControlNet model successfully!")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(ckpt_path,
                                                             controlnet=controlnet,
                                                             torch_dtype=torch.float16,
                                                             cache_dir=f"{args.ckpt_dir}/controlnet_pipeline")
    # UniPC is a faster scheduler for this pipeline.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    print("Loaded the ControlNet pipeline successfully!")
    return pipe


def controlnet(text_prompt, pil_image, pipe):
    image = np.array(pil_image)
    # Condition the generation on the Canny edge map of the input image.
    image = cv2.Canny(image=image, threshold1=100, threshold2=200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)

    seed = random.randint(0, 10000)
    generator = torch.manual_seed(seed)
    image = pipe(text_prompt,
                 num_inference_steps=50,
                 generator=generator,
                 image=canny_image,
                 output_type='pil').images[0]
    return image


def load_instructpix2pix_pipeline(args, ckpt_path="timbrooks/instruct-pix2pix"):
    from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float16)
    pipe.to(args.device)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe


def instructpix2pix(pil_image, text_prompt, pipe):
    image = pil_image.resize((512, 512))
    seed = random.randint(0, 10000)
    generator = torch.manual_seed(seed)
    image = pipe(prompt=text_prompt,
                 image=image,
                 generator=generator,
                 num_inference_steps=20,
                 image_guidance_scale=1.1).images[0]
    return image


if __name__ == "__main__":
    args = arg_parse()

    # Load the FontDiffuser pipeline and run one sampling pass.
    pipe = load_fontdiffuser_pipeline(args=args)
    out_image = sampling(args=args, pipe=pipe)
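# Illustrative demo-mode sketch (an assumption, not executed by this script):
# texturizing a generated glyph with the ControlNet helpers above. The prompt
# and output path are hypothetical placeholders.
#
#   controlnet_pipe = load_controlnet_pipeline(args=args)
#   textured = controlnet(text_prompt="ancient bronze, engraved metal texture",
#                         pil_image=out_image,
#                         pipe=controlnet_pipe)
#   textured.save("outputs/controlnet_result.png")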