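# edit.py
# Text-guided face editing with M3Face: an Imagic-style procedure that optimizes the
# prompt embedding, finetunes selected UNet blocks, and then generates edited images
# with a ControlNet conditioned on a segmentation mask or facial landmarks.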
import os
import argparse
from tqdm.auto import tqdm
from packaging import version
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms
from diffusers import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
PNDMScheduler,
    AmusedInpaintPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils import load_image
from transformers import AutoTokenizer, CLIPFeatureExtractor, PretrainedConfig
from PIL import Image
from utils.mclip import *
def parse_args():
parser = argparse.ArgumentParser(description="Edit images with M3Face.")
parser.add_argument(
"--prompt",
type=str,
default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
help="The input text prompt for image generation."
)
parser.add_argument(
"--condition",
type=str,
default="mask",
choices=["mask", "landmark"],
help="Use segmentation mask or facial landmarks for image generation."
)
parser.add_argument(
"--image_path",
type=str,
default=None,
help="Path to the input image."
)
parser.add_argument(
"--condition_path",
type=str,
default=None,
help="Path to the original mask/landmark image."
)
parser.add_argument(
"--edit_condition_path",
type=str,
default=None,
help="Path to the target mask/landmark image."
)
parser.add_argument(
"--output_dir",
type=str,
default='output/',
help="The output directory where the results will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--edit_condition", action="store_true")
parser.add_argument("--load_unet_from_local", action="store_true")
parser.add_argument("--save_unet", action="store_true")
parser.add_argument("--unet_local_path", type=str, default=None)
parser.add_argument("--load_finetune_from_local", action="store_true")
parser.add_argument("--finetune_path", type=str, default=None)
parser.add_argument("--use_english", action="store_true", help="Use the English models.")
parser.add_argument("--embedding_optimize_it", type=int, default=500)
parser.add_argument("--model_finetune_it", type=int, default=1000)
parser.add_argument("--alpha", nargs="+", type=float, default=[0.8, 0.9, 1, 1.1])
parser.add_argument("--num_inference_steps", nargs="+", type=int, default=[20, 40, 50])
parser.add_argument("--unet_layer", type=str, default="2and3",
help="Which UNet layers in the SD to finetune.")
args = parser.parse_args()
return args
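
# Example invocation (paths are illustrative placeholders):
#   python edit.py --prompt "This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup." \
#       --condition mask --image_path face.png --condition_path mask.png \
#       --edit_condition --output_dir output/ --seed 42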
def get_muse(args):
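    # Load the aMUSEd inpainting pipeline that synthesizes an edited segmentation mask or
    # landmark image from the text prompt; the revision selects the matching checkpoint.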
muse_model_name = 'm3face/FaceConditioning'
if args.condition == 'mask':
muse_revision = 'segmentation'
elif args.condition == 'landmark':
muse_revision = 'landmark'
scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
pipeline = AmusedInpaintPipeline(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=uvit2,
scheduler=scheduler
).to("cuda")
return pipeline
def import_model_class_from_model_name(sd_model_name):
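    # Read the text_encoder config from the hub to decide which encoder class to load:
    # CLIPTextModel for Stable Diffusion v1.5, RobertaSeriesModelWithTransformation for AltDiffusion-m18.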
text_encoder_config = PretrainedConfig.from_pretrained(
sd_model_name,
subfolder="text_encoder",
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def preprocess(image, condition, prompt, tokenizer):
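    # Resize and center-crop both images to 512x512, normalize the face image to [-1, 1]
    # (the condition image stays in [0, 1]), and tokenize the prompt with padding to max length.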
image_transforms = transforms.Compose(
[
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(512),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
condition_transforms = transforms.Compose(
[
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(512),
transforms.ToTensor(),
]
)
image = image_transforms(image)
condition = condition_transforms(condition)
inputs = tokenizer(
[prompt], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return image, condition, inputs.input_ids, inputs.attention_mask
def main(args):
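    # Make sure the output directory exists before embeddings and images are written to it.
    os.makedirs(args.output_dir, exist_ok=True)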
if args.use_english:
sd_model_name = 'runwayml/stable-diffusion-v1-5'
controlnet_model_name = 'm3face/FaceControlNet'
if args.condition == 'mask':
controlnet_revision = 'segmentation-english'
elif args.condition == 'landmark':
controlnet_revision = 'landmark-english'
else:
sd_model_name = 'BAAI/AltDiffusion-m18'
controlnet_model_name = 'm3face/FaceControlNet'
if args.condition == 'mask':
controlnet_revision = 'segmentation-mlin'
elif args.condition == 'landmark':
controlnet_revision = 'landmark-mlin'
# ========== set up models ==========
vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
text_encoder_cls = import_model_class_from_model_name(sd_model_name)
text_encoder = text_encoder_cls.from_pretrained(sd_model_name, subfolder="text_encoder")
noise_scheduler = DDPMScheduler.from_pretrained(sd_model_name, subfolder="scheduler")
if args.load_unet_from_local:
unet = UNet2DConditionModel.from_pretrained(args.unet_local_path)
else:
unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
if args.edit_condition:
muse = get_muse(args)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet.requires_grad_(False)
unet.requires_grad_(False)
vae.eval()
text_encoder.eval()
controlnet.eval()
unet.eval()
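    # All pretrained components start out frozen; only the selected UNet up-blocks are
    # unfrozen later, during the model finetuning stage.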
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
print(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
controlnet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# ========== select params to optimize ==========
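    # Collect the UNet decoder (up_blocks) parameters in order; --unet_layer picks which
    # slice of this list is finetuned (the layer counts are noted in the comments below).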
params = []
for name, param in unet.named_parameters():
        if name.startswith('up_blocks'):
params.append(param)
if args.unet_layer == 'only1': # 116 layers
params_to_optimize = [
{'params': params[38:154]},
]
elif args.unet_layer == 'only2': # 116 layers
params_to_optimize = [
{'params': params[154:270]},
]
elif args.unet_layer == 'only3': # 114 layers
params_to_optimize = [
{'params': params[270:]},
]
elif args.unet_layer == '1and2': # 232 layers
params_to_optimize = [
{'params': params[38:270]},
]
elif args.unet_layer == '2and3': # 230 layers
params_to_optimize = [
{'params': params[154:]},
]
elif args.unet_layer == 'all': # all layers
params_to_optimize = [
{'params': params},
]
image = Image.open(args.image_path).convert('RGB')
condition = Image.open(args.condition_path).convert('RGB')
image, condition, input_ids, attention_mask = preprocess(image, condition, args.prompt, tokenizer)
# Move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae.to(device, dtype=torch.float32)
unet.to(device, dtype=torch.float32)
text_encoder.to(device, dtype=torch.float32)
controlnet.to(device)
image = image.to(device).unsqueeze(0)
condition = condition.to(device).unsqueeze(0)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# ========== imagic ==========
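    # Imagic-style editing: (1) optimize the text embedding so the frozen model reconstructs
    # the input image, (2) finetune the selected UNet blocks with that embedding, then
    # (3) generate by interpolating between the original and optimized embeddings.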
    # Encode the input image once; the latent is needed for both optimization stages below,
    # regardless of whether the embeddings are loaded from disk or computed here.
    init_latent = vae.encode(image.to(dtype=torch.float32)).latent_dist.sample()
    init_latent = init_latent * vae.config.scaling_factor
    if args.load_finetune_from_local:
        print('Loading embeddings from local ...')
        orig_emb = torch.load(os.path.join(args.finetune_path, 'orig_emb.pt'))
        emb = torch.load(os.path.join(args.finetune_path, 'emb.pt'))
    else:
        if not args.use_english:
            orig_emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
        else:
            orig_emb = text_encoder(input_ids)[0]
        emb = orig_emb.clone()
        torch.save(orig_emb, os.path.join(args.output_dir, 'orig_emb.pt'))
        torch.save(emb, os.path.join(args.output_dir, 'emb.pt'))
# 1. Optimize the embedding
print('1. Optimize the embedding')
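    # Only the embedding receives gradients here; VAE, text encoder, ControlNet, and UNet
    # remain frozen. The objective is the usual diffusion denoising MSE at random timesteps.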
unet.eval()
emb.requires_grad = True
lr = 0.001
it = args.embedding_optimize_it # 500
opt = torch.optim.Adam([emb], lr=lr)
history = []
pbar = tqdm(
range(it),
initial=0,
desc="Optimize Steps",
)
global_step = 0
for i in pbar:
opt.zero_grad()
noise = torch.randn_like(init_latent)
bsz = init_latent.shape[0]
t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
t_enc = t_enc.long()
z = noise_scheduler.add_noise(init_latent, noise, t_enc)
controlnet_image = condition.to(dtype=torch.float32)
down_block_res_samples, mid_block_res_sample = controlnet(
z,
t_enc,
encoder_hidden_states=emb,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
pred_noise = unet(
z,
t_enc,
encoder_hidden_states=emb,
down_block_additional_residuals=[
sample.to(dtype=torch.float32) for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
loss.backward()
global_step += 1
pbar.set_postfix({"loss": loss.item()})
history.append(loss.item())
opt.step()
opt.zero_grad()
# 2. Finetune the model
print('2. Finetune the model')
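    # Freeze the optimized embedding and finetune the selected UNet up-blocks with the same
    # denoising objective, so the model reconstructs the input image from that embedding.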
emb.requires_grad = False
unet.requires_grad_(True)
unet.train()
lr = 5e-5
it = args.model_finetune_it # 1000
opt = torch.optim.Adam(params_to_optimize, lr=lr)
history = []
pbar = tqdm(
range(it),
initial=0,
desc="Finetune Steps",
)
global_step = 0
for i in pbar:
opt.zero_grad()
noise = torch.randn_like(init_latent)
bsz = init_latent.shape[0]
t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
t_enc = t_enc.long()
z = noise_scheduler.add_noise(init_latent, noise, t_enc)
controlnet_image = condition.to(dtype=torch.float32)
down_block_res_samples, mid_block_res_sample = controlnet(
z,
t_enc,
encoder_hidden_states=emb,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
pred_noise = unet(
z,
t_enc,
encoder_hidden_states=emb,
down_block_additional_residuals=[
sample.to(dtype=torch.float32) for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
loss.backward()
global_step += 1
pbar.set_postfix({"loss": loss.item()})
history.append(loss.item())
opt.step()
opt.zero_grad()
# 3. Generate Images
print("3. Generating images... ")
unet.eval()
controlnet.eval()
if args.edit_condition_path is None:
edit_condition = load_image(args.condition_path)
else:
edit_condition = load_image(args.edit_condition_path)
if args.edit_condition:
        # Build the inpainting mask: white (255) marks the region aMUSEd is allowed to rewrite.
        edit_mask = Image.new("L", (256, 256), 0)
        for i in range(256):
            for j in range(256):
                if 40 < i < 220 and 20 < j < 256:
                    edit_mask.putpixel((i, j), 255)
        if args.condition == 'mask':
            condition_name = 'segmentation'
        elif args.condition == 'landmark':
            condition_name = 'landmark'
        edit_prompt = f"Generate face {condition_name} | {args.prompt}"
input_image = edit_condition.resize((256, 256)).convert("RGB")
edit_condition = muse(edit_prompt, input_image, edit_mask, num_inference_steps=30).images[0].resize((512, 512))
edit_condition.save(f'{args.output_dir}/edited_condition.png')
# remove muse and empty cache
del muse
torch.cuda.empty_cache()
if sd_model_name.startswith('BAAI'):
scheduler = PNDMScheduler.from_pretrained(
sd_model_name,
subfolder='scheduler',
)
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
sd_model_name,
subfolder='feature_extractor',
)
pipeline = StableDiffusionControlNetPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor
)
else:
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
sd_model_name,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
safety_checker=None,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=device).manual_seed(args.seed)
with torch.autocast("cuda"):
image = pipeline(
image=edit_condition, prompt_embeds=emb, num_inference_steps=20, generator=generator
).images[0]
image.save(f'{args.output_dir}/reconstruct.png')
# Interpolate the embedding
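    # alpha = 1 reproduces the original prompt embedding (strongest edit), while smaller
    # values stay closer to the optimized embedding that reconstructs the input image.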
for num_inference_steps in args.num_inference_steps:
for alpha in args.alpha:
new_emb = alpha * orig_emb + (1 - alpha) * emb
with torch.autocast("cuda"):
image = pipeline(
image=edit_condition, prompt_embeds=new_emb, num_inference_steps=num_inference_steps, generator=generator
).images[0]
image.save(f'{args.output_dir}/image_{num_inference_steps}_{alpha}.png')
if args.save_unet:
print('Saving the unet model...')
unet.save_pretrained(f'{args.output_dir}/unet')
if __name__ == '__main__':
args = parse_args()
main(args)