import argparse
import os
import time

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from transformers import AutoTokenizer, CLIPFeatureExtractor
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
from diffusers.utils import load_image

from utils.mclip import *  # provides MultilingualCLIP
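

# This script generates face images in two stages: an aMUSEd (MUSE-style)
# pipeline first turns the text prompt into a condition image (segmentation
# mask or facial landmarks) unless --condition_path provides one, and a
# ControlNet-guided Stable Diffusion pipeline then renders the final face.
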
def parse_args():
    parser = argparse.ArgumentParser(description="Generate images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation."
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use a segmentation mask or facial landmarks for image generation."
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the condition mask/landmark image. The condition is generated from the prompt if this is not given."
    )
    parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument(
        "--additional_prompt",
        type=str,
        default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD"
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting"
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/",
        help="The output directory where the results will be written.",
    )
    args = parser.parse_args()
    return args
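

# Example invocation (the script name here is illustrative):
#   python generate.py --prompt "A young man with a beard." --condition landmark \
#       --use_english --num_samples 4 --seed 42
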
def get_controlnet(args):
    if args.use_english:
        # English setup: Stable Diffusion v1.5 with the English ControlNet branch.
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, use_safetensors=True, revision=controlnet_revision)
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
        ).to("cuda")
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
        pipeline.enable_model_cpu_offload()
    else:
        # Multilingual setup: AltDiffusion-m18 with its Roberta-based text encoder
        # and the multilingual ControlNet branch, assembled component by component.
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'
        vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
        tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
        text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
        # Load the stock PNDM scheduler only for its config, then swap in UniPC for faster sampling.
        scheduler = PNDMScheduler.from_pretrained(sd_model_name, subfolder='scheduler')
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(sd_model_name, subfolder='feature_extractor')
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        ).to('cuda')
    return pipeline
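

# Condition generator: an aMUSEd pipeline loaded from m3face/FaceConditioning,
# using the repository branch that matches the condition type and the
# multilingual CLIP text encoder from utils.mclip.
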
def get_muse(args):
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'
    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
    pipeline = AmusedPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=uvit2,
        scheduler=scheduler
    ).to("cuda")
    return pipeline


if __name__ == '__main__':
    args = parse_args()

    # ========== set up face generation pipeline ==========
    controlnet = get_controlnet(args)

    # ========== set output directory ==========
    os.makedirs(args.output_dir, exist_ok=True)

    # ========== set random seed ==========
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator().manual_seed(args.seed)

    # ========== generation ==========
    run_id = int(time.time())  # timestamp that names this run's output files
    if args.condition_path:
        condition = load_image(args.condition_path).resize((512, 512))
    else:
        # No condition given: generate one from the prompt with the aMUSEd pipeline.
        muse = get_muse(args)
        if args.condition == 'mask':
            muse_added_prompt = 'Generate face segmentation | '
        elif args.condition == 'landmark':
            muse_added_prompt = 'Generate face landmark | '
        muse_prompt = muse_added_prompt + args.prompt
        condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
        if args.save_condition:
            condition.save(f'{args.output_dir}/{run_id}_condition.png')

    # Fixed initial latents keep the whole batch reproducible for a given seed.
    latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
    prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
    images = controlnet(
        prompt,
        image=condition,
        num_inference_steps=args.num_inference_steps,
        negative_prompt=args.negative_prompt,
        generator=generator,
        latents=latents,
        num_images_per_prompt=args.num_samples,
    ).images
    for i, image in enumerate(images):
        image.save(f'{args.output_dir}/{run_id}_{i}.png')